home *** CD-ROM | disk | FTP | other *** search
/ Inter.Net 55-1 / Inter.Net 55-1.iso / CBuilder / Setup / BCB / data.z / atlimpl.cpp < prev    next >
Encoding:
C/C++ Source or Header  |  1998-02-09  |  66.5 KB  |  2,853 lines

  1. // This is a part of the Active Template Library.
  2. // Copyright (C) 1996-1997 Microsoft Corporation
  3. // All rights reserved.
  4. //
  5. // This source code is only intended as a supplement to the
  6. // Active Template Library Reference and related
  7. // electronic documentation provided with the library.
  8. // See these sources for detailed information regarding the
  9. // Active Template Library product.
  10.  
  11. #ifndef __ATLBASE_H__
  12.     #error atlimpl.cpp requires atlbase.h to be included first
  13. #endif
  14.  
  15. #if defined(BCC32_COMPAT)
  16. #pragma warn -aus
  17. #endif
  18.  
  19. extern "C" const IID IID_IRegistrar = {0x44EC053B,0x400F,0x11D0,{0x9D,0xCD,0x00,0xA0,0xC9,0x03,0x91,0xD3}};
  20. #ifndef _ATL_DLL_IMPL
  21. extern "C" const CLSID CLSID_Registrar = {0x44EC053A,0x400F,0x11D0,{0x9D,0xCD,0x00,0xA0,0xC9,0x03,0x91,0xD3}};
  22. #endif
  23.  
  24. #include "atlconv.cpp"
  25. #ifdef _DEBUG
  26. #include <stdio.h>
  27. #include <stdarg.h>
  28. #endif
  29.  
  30. #ifndef ATL_NO_NAMESPACE
  31. namespace ATL
  32. {
  33. #endif
  34.  
  35. // used in thread pooling
  36. UINT CComApartment::ATL_CREATE_OBJECT = 0;
  37.  
  38. #ifdef __ATLCOM_H__
  39.  
  40. /////////////////////////////////////////////////////////////////////////////
  41. // AtlReportError
  42.  
  43. HRESULT WINAPI AtlReportError(const CLSID& clsid, UINT nID, const IID& iid,
  44.     HRESULT hRes, HINSTANCE hInst)
  45. {
  46.     return AtlSetErrorInfo(clsid, (LPCOLESTR)MAKEINTRESOURCE(nID), 0, NULL, iid, hRes, hInst);
  47. }
  48.  
  49. HRESULT WINAPI AtlReportError(const CLSID& clsid, UINT nID, DWORD dwHelpID,
  50.     LPCOLESTR lpszHelpFile, const IID& iid, HRESULT hRes, HINSTANCE hInst)
  51. {
  52.     return AtlSetErrorInfo(clsid, (LPCOLESTR)MAKEINTRESOURCE(nID), dwHelpID,
  53.         lpszHelpFile, iid, hRes, hInst);
  54. }
  55.  
  56. #ifndef OLE2ANSI
  57. HRESULT WINAPI AtlReportError(const CLSID& clsid, LPCSTR lpszDesc,
  58.     DWORD dwHelpID, LPCSTR lpszHelpFile, const IID& iid, HRESULT hRes)
  59. {
  60.     _ASSERTE(lpszDesc != NULL);
  61.     USES_CONVERSION;
  62.     return AtlSetErrorInfo(clsid, A2COLE(lpszDesc), dwHelpID, A2CW(lpszHelpFile),
  63.         iid, hRes, NULL);
  64. }
  65.  
  66. HRESULT WINAPI AtlReportError(const CLSID& clsid, LPCSTR lpszDesc,
  67.     const IID& iid, HRESULT hRes)
  68. {
  69.     _ASSERTE(lpszDesc != NULL);
  70.     USES_CONVERSION;
  71.     return AtlSetErrorInfo(clsid, A2COLE(lpszDesc), 0, NULL, iid, hRes, NULL);
  72. }
  73. #endif
  74.  
  75. HRESULT WINAPI AtlReportError(const CLSID& clsid, LPCOLESTR lpszDesc,
  76.     const IID& iid, HRESULT hRes)
  77. {
  78.     return AtlSetErrorInfo(clsid, lpszDesc, 0, NULL, iid, hRes, NULL);
  79. }
  80.  
  81. HRESULT WINAPI AtlReportError(const CLSID& clsid, LPCOLESTR lpszDesc, DWORD dwHelpID,
  82.     LPCOLESTR lpszHelpFile, const IID& iid, HRESULT hRes)
  83. {
  84.     return AtlSetErrorInfo(clsid, lpszDesc, dwHelpID, lpszHelpFile, iid, hRes, NULL);
  85. }
  86.  
  87. #endif //__ATLCOM_H__
  88.  
  89. /////////////////////////////////////////////////////////////////////////////
  90. // CComBSTR
  91.  
  92. CComBSTR& CComBSTR::operator=(const CComBSTR& src)
  93. {
  94.     if (m_str != src.m_str)
  95.     {
  96.         if (m_str)
  97.             ::SysFreeString(m_str);
  98.         m_str = src.Copy();
  99.     }
  100.     return *this;
  101. }
  102.  
  103. CComBSTR& CComBSTR::operator=(LPCOLESTR pSrc)
  104. {
  105.     ::SysFreeString(m_str);
  106.     m_str = ::SysAllocString(pSrc);
  107.     return *this;
  108. }
  109.  
  110. void CComBSTR::Append(LPCOLESTR lpsz, int nLen)
  111. {
  112.     int n1 = Length();
  113.     BSTR b = SysAllocStringLen(NULL, n1+nLen);
  114.     memcpy(b, m_str, n1*sizeof(OLECHAR));
  115.     memcpy(b+n1, lpsz, nLen*sizeof(OLECHAR));
  116.     b[n1+nLen] = NULL;
  117.     SysFreeString(m_str);
  118.     m_str = b;
  119. }
  120.  
  121. #ifndef OLE2ANSI
  122. void CComBSTR::Append(LPCSTR lpsz)
  123. {
  124.     USES_CONVERSION;
  125.     LPCOLESTR lpo = A2COLE(lpsz);
  126.     Append(lpo, ocslen(lpo));
  127. }
  128.  
  129. CComBSTR::CComBSTR(LPCSTR pSrc)
  130. {
  131.     USES_CONVERSION;
  132.     m_str = ::SysAllocString(A2COLE(pSrc));
  133. }
  134.  
  135. CComBSTR::CComBSTR(int nSize, LPCSTR sz)
  136. {
  137.     USES_CONVERSION;
  138.     m_str = ::SysAllocStringLen(A2COLE(sz), nSize);
  139. }
  140.  
  141. CComBSTR& CComBSTR::operator=(LPCSTR pSrc)
  142. {
  143.     USES_CONVERSION;
  144.     ::SysFreeString(m_str);
  145.     m_str = ::SysAllocString(A2COLE(pSrc));
  146.     return *this;
  147. }
  148. #endif
  149.  
  150. HRESULT CComBSTR::ReadFromStream(IStream* pStream)
  151. {
  152.     _ASSERTE(pStream != NULL);
  153.     _ASSERTE(m_str == NULL); // should be empty
  154.     ULONG cb;
  155.     ULONG cbStrLen;
  156.     HRESULT hr = pStream->Read((void*) &cbStrLen, sizeof(cbStrLen), &cb);
  157.     if (FAILED(hr))
  158.         return hr;
  159.     if (cbStrLen != 0)
  160.     {
  161.         //subtract size for terminating NULL which we wrote out
  162.         //since SysAllocStringByteLen overallocates for the NULL
  163.         m_str = SysAllocStringByteLen(NULL, cbStrLen-sizeof(OLECHAR));
  164.         if (m_str == NULL)
  165.             hr = E_OUTOFMEMORY;
  166.         else
  167.             hr = pStream->Read((void*) m_str, cbStrLen, &cb);
  168.     }
  169.     return hr;
  170. }
  171.  
  172. HRESULT CComBSTR::WriteToStream(IStream* pStream)
  173. {
  174.     _ASSERTE(pStream != NULL);
  175.     ULONG cb;
  176.     ULONG cbStrLen = m_str ? SysStringByteLen(m_str)+sizeof(OLECHAR) : 0;
  177.     HRESULT hr = pStream->Write((void*) &cbStrLen, sizeof(cbStrLen), &cb);
  178.     if (FAILED(hr))
  179.         return hr;
  180.     return cbStrLen ? pStream->Write((void*) m_str, cbStrLen, &cb) : S_OK;
  181. }
  182.  
  183. /////////////////////////////////////////////////////////////////////////////
  184. // CComVariant
  185.  
  186. CComVariant& CComVariant::operator=(BSTR bstrSrc)
  187. {
  188.     InternalClear();
  189.     vt = VT_BSTR;
  190.     bstrVal = ::SysAllocString(bstrSrc);
  191.     if (bstrVal == NULL && bstrSrc != NULL)
  192.     {
  193.         vt = VT_ERROR;
  194.         scode = E_OUTOFMEMORY;
  195.     }
  196.     return *this;
  197. }
  198.  
  199. CComVariant& CComVariant::operator=(LPCOLESTR lpszSrc)
  200. {
  201.     InternalClear();
  202.     vt = VT_BSTR;
  203.     bstrVal = ::SysAllocString(lpszSrc);
  204.  
  205.     if (bstrVal == NULL && lpszSrc != NULL)
  206.     {
  207.         vt = VT_ERROR;
  208.         scode = E_OUTOFMEMORY;
  209.     }
  210.     return *this;
  211. }
  212.  
  213. #ifndef OLE2ANSI
  214. CComVariant& CComVariant::operator=(LPCSTR lpszSrc)
  215. {
  216.     USES_CONVERSION;
  217.     InternalClear();
  218.     vt = VT_BSTR;
  219.     bstrVal = ::SysAllocString(A2COLE(lpszSrc));
  220.  
  221.     if (bstrVal == NULL && lpszSrc != NULL)
  222.     {
  223.         vt = VT_ERROR;
  224.         scode = E_OUTOFMEMORY;
  225.     }
  226.     return *this;
  227. }
  228. #endif
  229.  
  230. #if _MSC_VER>1020
  231. CComVariant& CComVariant::operator=(bool bSrc)
  232. {
  233.     if (vt != VT_BOOL)
  234.     {
  235.         InternalClear();
  236.         vt = VT_BOOL;
  237.     }
  238. #pragma warning(disable: 4310) // cast truncates constant value
  239.     boolVal = bSrc ? VARIANT_TRUE : VARIANT_FALSE;
  240. #pragma warning(default: 4310) // cast truncates constant value
  241.     return *this;
  242. }
  243. #endif
  244.  
  245. CComVariant& CComVariant::operator=(int nSrc)
  246. {
  247.     if (vt != VT_I4)
  248.     {
  249.         InternalClear();
  250.         vt = VT_I4;
  251.     }
  252.     lVal = nSrc;
  253.  
  254.     return *this;
  255. }
  256.  
  257. CComVariant& CComVariant::operator=(BYTE nSrc)
  258. {
  259.     if (vt != VT_UI1)
  260.     {
  261.         InternalClear();
  262.         vt = VT_UI1;
  263.     }
  264.     bVal = nSrc;
  265.     return *this;
  266. }
  267.  
  268. CComVariant& CComVariant::operator=(short nSrc)
  269. {
  270.     if (vt != VT_I2)
  271.     {
  272.         InternalClear();
  273.         vt = VT_I2;
  274.     }
  275.     iVal = nSrc;
  276.     return *this;
  277. }
  278.  
  279. CComVariant& CComVariant::operator=(long nSrc)
  280. {
  281.     if (vt != VT_I4)
  282.     {
  283.         InternalClear();
  284.         vt = VT_I4;
  285.     }
  286.     lVal = nSrc;
  287.     return *this;
  288. }
  289.  
  290. CComVariant& CComVariant::operator=(float fltSrc)
  291. {
  292.     if (vt != VT_R4)
  293.     {
  294.         InternalClear();
  295.         vt = VT_R4;
  296.     }
  297.     fltVal = fltSrc;
  298.     return *this;
  299. }
  300.  
  301. CComVariant& CComVariant::operator=(double dblSrc)
  302. {
  303.     if (vt != VT_R8)
  304.     {
  305.         InternalClear();
  306.         vt = VT_R8;
  307.     }
  308.     dblVal = dblSrc;
  309.     return *this;
  310. }
  311.  
  312. CComVariant& CComVariant::operator=(CY cySrc)
  313. {
  314.     if (vt != VT_CY)
  315.     {
  316.         InternalClear();
  317.         vt = VT_CY;
  318.     }
  319. #if !defined(BCC32_COMPAT) || defined(_ANONYMOUS_STRUCT)
  320.     cyVal.Hi = cySrc.Hi;
  321.     cyVal.Lo = cySrc.Lo;
  322. #else
  323.     cyVal.s.Hi = cySrc.s.Hi;
  324.     cyVal.s.Lo = cySrc.s.Lo;
  325. #endif
  326.     return *this;
  327. }
  328.  
  329. CComVariant& CComVariant::operator=(IDispatch* pSrc)
  330. {
  331.     InternalClear();
  332.     vt = VT_DISPATCH;
  333.     pdispVal = pSrc;
  334.     // Need to AddRef as VariantClear will Release
  335.     if (pdispVal != NULL)
  336.         pdispVal->AddRef();
  337.     return *this;
  338. }
  339.  
  340. CComVariant& CComVariant::operator=(IUnknown* pSrc)
  341. {
  342.     InternalClear();
  343.     vt = VT_UNKNOWN;
  344.     punkVal = pSrc;
  345.  
  346.     // Need to AddRef as VariantClear will Release
  347.     if (punkVal != NULL)
  348.         punkVal->AddRef();
  349.     return *this;
  350. }
  351.  
  352. #if _MSC_VER>1020
  353. bool CComVariant::operator==(const VARIANT& varSrc)
  354. {
  355.     if (this == &varSrc)
  356.         return true;
  357.  
  358.     // Variants not equal if types don't match
  359.     if (vt != varSrc.vt)
  360.         return false;
  361.  
  362.     // Check type specific values
  363.     switch (vt)
  364.     {
  365.         case VT_EMPTY:
  366.         case VT_NULL:
  367.             return true;
  368.  
  369.         case VT_BOOL:
  370.             return boolVal == varSrc.boolVal;
  371.  
  372.         case VT_UI1:
  373.             return bVal == varSrc.bVal;
  374.  
  375.         case VT_I2:
  376.             return iVal == varSrc.iVal;
  377.  
  378.         case VT_I4:
  379.             return lVal == varSrc.lVal;
  380.  
  381.         case VT_R4:
  382.             return fltVal == varSrc.fltVal;
  383.  
  384.         case VT_R8:
  385.             return dblVal == varSrc.dblVal;
  386.  
  387.         case VT_BSTR:
  388.             return (::SysStringByteLen(bstrVal) == ::SysStringByteLen(varSrc.bstrVal)) &&
  389.                     (::memcmp(bstrVal, varSrc.bstrVal, ::SysStringByteLen(bstrVal)) == 0);
  390.  
  391.         case VT_ERROR:
  392.             return scode == varSrc.scode;
  393.  
  394.         case VT_DISPATCH:
  395.             return pdispVal == varSrc.pdispVal;
  396.  
  397.         case VT_UNKNOWN:
  398.             return punkVal == varSrc.punkVal;
  399.  
  400.         default:
  401.             _ASSERTE(false);
  402.             // fall through
  403.     }
  404.  
  405.     return false;
  406. }
  407. #else
  408. BOOL CComVariant::operator==(const VARIANT& varSrc)
  409. {
  410.     if (this == &varSrc)
  411.         return TRUE;
  412.  
  413.     // Variants not equal if types don't match
  414.     if (vt != varSrc.vt)
  415.         return FALSE;
  416.  
  417.     // Check type specific values
  418.     switch (vt)
  419.     {
  420.         case VT_EMPTY:
  421.         case VT_NULL:
  422.             return TRUE;
  423.  
  424.         case VT_BOOL:
  425.             return boolVal == varSrc.boolVal;
  426.  
  427.         case VT_UI1:
  428.             return bVal == varSrc.bVal;
  429.  
  430.         case VT_I2:
  431.             return iVal == varSrc.iVal;
  432.  
  433.         case VT_I4:
  434.             return lVal == varSrc.lVal;
  435.  
  436.         case VT_R4:
  437.             return fltVal == varSrc.fltVal;
  438.  
  439.         case VT_R8:
  440.             return dblVal == varSrc.dblVal;
  441.  
  442.         case VT_BSTR:
  443.             return (::SysStringByteLen(bstrVal) == ::SysStringByteLen(varSrc.bstrVal)) &&
  444.                     (::memcmp(bstrVal, varSrc.bstrVal, ::SysStringByteLen(bstrVal)) == 0);
  445.  
  446.         case VT_ERROR:
  447.             return scode == varSrc.scode;
  448.  
  449.         case VT_DISPATCH:
  450.             return pdispVal == varSrc.pdispVal;
  451.  
  452.         case VT_UNKNOWN:
  453.             return punkVal == varSrc.punkVal;
  454.  
  455.         default:
  456.             _ASSERTE(FALSE);
  457.             // fall through
  458.     }
  459.  
  460.     return FALSE;
  461. }
  462. #endif
  463.  
  464. HRESULT CComVariant::Attach(VARIANT* pSrc)
  465. {
  466.     // Clear out the variant
  467.     HRESULT hr = Clear();
  468.     if (!FAILED(hr))
  469.     {
  470.         // Copy the contents and give control to CComVariant
  471.         memcpy(this, pSrc, sizeof(VARIANT));
  472.         VariantInit(pSrc);
  473.         hr = S_OK;
  474.     }
  475.     return hr;
  476. }
  477.  
  478. HRESULT CComVariant::Detach(VARIANT* pDest)
  479. {
  480.     // Clear out the variant
  481.     HRESULT hr = ::VariantClear(pDest);
  482.     if (!FAILED(hr))
  483.     {
  484.         // Copy the contents and remove control from CComVariant
  485.         memcpy(pDest, this, sizeof(VARIANT));
  486.         vt = VT_EMPTY;
  487.         hr = S_OK;
  488.     }
  489.     return hr;
  490. }
  491.  
  492. HRESULT CComVariant::ChangeType(VARTYPE vtNew, const VARIANT* pSrc)
  493. {
  494.     VARIANT* pVar = const_cast<VARIANT*>(pSrc);
  495.     // Convert in place if pSrc is NULL
  496.     if (pVar == NULL)
  497.         pVar = this;
  498.     // Do nothing if doing in place convert and vts not different
  499.     return ::VariantChangeType(this, pVar, 0, vtNew);
  500. }
  501.  
  502. HRESULT CComVariant::InternalClear()
  503. {
  504.     HRESULT hr = Clear();
  505.     _ASSERTE(SUCCEEDED(hr));
  506.     if (FAILED(hr))
  507.     {
  508.         vt = VT_ERROR;
  509.         scode = hr;
  510.     }
  511.     return hr;
  512. }
  513.  
  514. void CComVariant::InternalCopy(const VARIANT* pSrc)
  515. {
  516.     HRESULT hr = Copy(pSrc);
  517.     if (FAILED(hr))
  518.     {
  519.         vt = VT_ERROR;
  520.         scode = hr;
  521.     }
  522. }
  523.  
  524.  
  525. HRESULT CComVariant::WriteToStream(IStream* pStream)
  526. {
  527.     HRESULT hr = pStream->Write(&vt, sizeof(VARTYPE), NULL);
  528.     if (FAILED(hr))
  529.         return hr;
  530.  
  531.     int cbWrite = 0;
  532.     switch (vt)
  533.     {
  534.     case VT_UNKNOWN:
  535.     case VT_DISPATCH:
  536.         {
  537.             CComPtr<IPersistStream> spStream;
  538.             if (punkVal != NULL)
  539.             {
  540.                 hr = punkVal->QueryInterface(IID_IPersistStream, (void**)&spStream);
  541.                 if (FAILED(hr))
  542.                     return hr;
  543.             }
  544.             if (spStream != NULL)
  545.                 return OleSaveToStream(spStream, pStream);
  546.             else
  547.                 return WriteClassStm(pStream, CLSID_NULL);
  548.         }
  549.     case VT_UI1:
  550.     case VT_I1:
  551.         cbWrite = sizeof(BYTE);
  552.         break;
  553.     case VT_I2:
  554.     case VT_UI2:
  555.     case VT_BOOL:
  556.         cbWrite = sizeof(short);
  557.         break;
  558.     case VT_I4:
  559.     case VT_UI4:
  560.     case VT_R4:
  561.     case VT_INT:
  562.     case VT_UINT:
  563.     case VT_ERROR:
  564.         cbWrite = sizeof(long);
  565.         break;
  566.     case VT_R8:
  567.     case VT_CY:
  568.     case VT_DATE:
  569.         cbWrite = sizeof(double);
  570.         break;
  571.     default:
  572.         break;
  573.     }
  574.     if (cbWrite != 0)
  575.         return pStream->Write((void*) &bVal, cbWrite, NULL);
  576.  
  577.     CComBSTR bstrWrite;
  578.     CComVariant varBSTR;
  579.     if (vt != VT_BSTR)
  580.     {
  581.         hr = VariantChangeType(&varBSTR, this, VARIANT_NOVALUEPROP, VT_BSTR);
  582.         if (FAILED(hr))
  583.             return hr;
  584.         bstrWrite = varBSTR.bstrVal;
  585.     }
  586.     else
  587.         bstrWrite = bstrVal;
  588.  
  589.     return bstrWrite.WriteToStream(pStream);
  590. }
  591.  
  592. HRESULT CComVariant::ReadFromStream(IStream* pStream)
  593. {
  594.     _ASSERTE(pStream != NULL);
  595.     HRESULT hr;
  596.     hr = VariantClear(this);
  597.     if (FAILED(hr))
  598.         return hr;
  599.     VARTYPE vtRead;
  600.     hr = pStream->Read(&vtRead, sizeof(VARTYPE), NULL);
  601.     if (FAILED(hr))
  602.         return hr;
  603.  
  604.     vt = vtRead;
  605.     int cbRead = 0;
  606.     switch (vtRead)
  607.     {
  608.     case VT_UNKNOWN:
  609.     case VT_DISPATCH:
  610.         {
  611.             punkVal = NULL;
  612.             hr = OleLoadFromStream(pStream, 
  613.                 (vtRead == VT_UNKNOWN) ? IID_IUnknown : IID_IDispatch, 
  614.                 (void**)&punkVal);
  615.             if (hr == REGDB_E_CLASSNOTREG)
  616.                 hr = S_OK;
  617.             return S_OK;
  618.         }
  619.     case VT_UI1:
  620.     case VT_I1:
  621.         cbRead = sizeof(BYTE);
  622.         break;
  623.     case VT_I2:
  624.     case VT_UI2:
  625.     case VT_BOOL:
  626.         cbRead = sizeof(short);
  627.         break;
  628.     case VT_I4:
  629.     case VT_UI4:
  630.     case VT_R4:
  631.     case VT_INT:
  632.     case VT_UINT:
  633.     case VT_ERROR:
  634.         cbRead = sizeof(long);
  635.         break;
  636.     case VT_R8:
  637.     case VT_CY:
  638.     case VT_DATE:
  639.         cbRead = sizeof(double);
  640.         break;
  641.     default:
  642.         break;
  643.     }
  644.     if (cbRead != 0)
  645.         return pStream->Read((void*) &bVal, cbRead, NULL);
  646.     CComBSTR bstrRead;
  647.  
  648.     hr = bstrRead.ReadFromStream(pStream);
  649.     if (FAILED(hr))
  650.         return hr;
  651.     vt = VT_BSTR;
  652.     bstrVal = bstrRead.Detach();
  653.     if (vtRead != VT_BSTR)
  654.         hr = ChangeType(vtRead);
  655.     return hr;
  656. }
  657.  
  658. #ifdef __ATLCOM_H__
  659.  
  660. /////////////////////////////////////////////////////////////////////////////
  661. // CComTypeInfoHolder
  662.  
  663. void CComTypeInfoHolder::AddRef()
  664. {
  665.     EnterCriticalSection(&_Module.m_csTypeInfoHolder);
  666.     m_dwRef++;
  667.     LeaveCriticalSection(&_Module.m_csTypeInfoHolder);
  668. }
  669.  
  670. void CComTypeInfoHolder::Release()
  671. {
  672.     EnterCriticalSection(&_Module.m_csTypeInfoHolder);
  673.     if (--m_dwRef == 0)
  674.     {
  675.         if (m_pInfo != NULL)
  676.             m_pInfo->Release();
  677.         m_pInfo = NULL;
  678.     }
  679.     LeaveCriticalSection(&_Module.m_csTypeInfoHolder);
  680. }
  681.  
  682. HRESULT CComTypeInfoHolder::GetTI(LCID lcid, ITypeInfo** ppInfo)
  683. {
  684.     //If this assert occurs then most likely didn't initialize properly
  685.     _ASSERTE(m_plibid != NULL && m_pguid != NULL);
  686.     _ASSERTE(ppInfo != NULL);
  687.     *ppInfo = NULL;
  688.  
  689.     HRESULT hRes = E_FAIL;
  690.     EnterCriticalSection(&_Module.m_csTypeInfoHolder);
  691.     if (m_pInfo == NULL)
  692.     {
  693.         ITypeLib* pTypeLib;
  694.         hRes = LoadRegTypeLib(*m_plibid, m_wMajor, m_wMinor, lcid, &pTypeLib);
  695.         if (SUCCEEDED(hRes))
  696.         {
  697.             ITypeInfo* pTypeInfo;
  698.             hRes = pTypeLib->GetTypeInfoOfGuid(*m_pguid, &pTypeInfo);
  699.             if (SUCCEEDED(hRes))
  700.                 m_pInfo = pTypeInfo;
  701.             pTypeLib->Release();
  702.         }
  703.     }
  704.     *ppInfo = m_pInfo;
  705.     if (m_pInfo != NULL)
  706.     {
  707.         m_pInfo->AddRef();
  708.                 // jmt.  Add another to work
  709.                 // around remote datamodule problem.
  710.                 m_pInfo->AddRef();
  711.         hRes = S_OK;
  712.     }
  713.     LeaveCriticalSection(&_Module.m_csTypeInfoHolder);
  714.     return hRes;
  715. }
  716.  
  717. HRESULT CComTypeInfoHolder::GetTypeInfo(UINT /*itinfo*/, LCID lcid,
  718.     ITypeInfo** pptinfo)
  719. {
  720.     HRESULT hRes = E_POINTER;
  721.     if (pptinfo != NULL)
  722.         hRes = GetTI(lcid, pptinfo);
  723.     return hRes;
  724. }
  725.  
  726. HRESULT CComTypeInfoHolder::GetIDsOfNames(REFIID /*riid*/, LPOLESTR* rgszNames,
  727.     UINT cNames, LCID lcid, DISPID* rgdispid)
  728. {
  729.     ITypeInfo* pInfo;
  730.     HRESULT hRes = GetTI(lcid, &pInfo);
  731.     if (pInfo != NULL)
  732.     {
  733.         hRes = pInfo->GetIDsOfNames(rgszNames, cNames, rgdispid);
  734.         pInfo->Release();
  735.     }
  736.     return hRes;
  737. }
  738.  
  739. HRESULT CComTypeInfoHolder::Invoke(IDispatch* p, DISPID dispidMember, REFIID /*riid*/,
  740.     LCID lcid, WORD wFlags, DISPPARAMS* pdispparams, VARIANT* pvarResult,
  741.     EXCEPINFO* pexcepinfo, UINT* puArgErr)
  742. {
  743.     SetErrorInfo(0, NULL);
  744.     ITypeInfo* pInfo;
  745.     HRESULT hRes = GetTI(lcid, &pInfo);
  746.     if (pInfo != NULL)
  747.     {
  748.         hRes = pInfo->Invoke(p, dispidMember, wFlags, pdispparams, pvarResult, pexcepinfo, puArgErr);
  749.         pInfo->Release();
  750.     }
  751.     return hRes;
  752. }
  753.  
  754. /////////////////////////////////////////////////////////////////////////////
  755. // QI implementation
  756.  
  757. #ifdef _ATL_DEBUG_QI
  758. HRESULT WINAPI AtlDumpIID(REFIID iid, LPCTSTR pszClassName, HRESULT hr)
  759. {
  760.     USES_CONVERSION;
  761.     CRegKey key;
  762.     TCHAR szName[100];
  763.     DWORD dwType,dw = sizeof(szName);
  764.  
  765.     LPOLESTR pszGUID = NULL;
  766.     StringFromCLSID(iid, &pszGUID);
  767.     OutputDebugString(pszClassName);
  768.     OutputDebugString(_T(" - "));
  769.  
  770.     // Attempt to find it in the interfaces section
  771.     key.Open(HKEY_CLASSES_ROOT, _T("Interface"));
  772.     if (key.Open(key, OLE2T(pszGUID)) == S_OK)
  773.     {
  774.         *szName = 0;
  775.         RegQueryValueEx(key.m_hKey, (LPTSTR)NULL, NULL, &dwType, (LPBYTE)szName, &dw);
  776.         OutputDebugString(szName);
  777.         goto cleanup;
  778.     }
  779.     // Attempt to find it in the clsid section
  780.     key.Open(HKEY_CLASSES_ROOT, _T("CLSID"));
  781.     if (key.Open(key, OLE2T(pszGUID)) == S_OK)
  782.     {
  783.         *szName = 0;
  784.         RegQueryValueEx(key.m_hKey, (LPTSTR)NULL, NULL, &dwType, (LPBYTE)szName, &dw);
  785.         OutputDebugString(_T("(CLSID\?\?\?) "));
  786.         OutputDebugString(szName);
  787.         goto cleanup;
  788.     }
  789.     OutputDebugString(OLE2T(pszGUID));
  790. cleanup:
  791.     if (hr != S_OK)
  792.         OutputDebugString(_T(" - failed"));
  793.     OutputDebugString(_T("\n"));
  794.     CoTaskMemFree(pszGUID);
  795.     return hr;
  796. }
  797. #endif
  798.  
  799. HRESULT WINAPI CComObjectRootBase::_Break(void* /* pv */, REFIID iid, void** /* ppvObject */, DWORD /* dw */)
  800. {
  801.     iid;
  802.     _ATLDUMPIID(iid, _T("Break due to QI for interface "), S_OK);
  803.     DebugBreak();
  804.     return S_FALSE;
  805. }
  806.  
  807. HRESULT WINAPI CComObjectRootBase::_NoInterface(void* /* pv */, REFIID /* iid */, void** /* ppvObject */, DWORD /* dw */)
  808. {
  809.     return E_NOINTERFACE;
  810. }
  811.  
  812. HRESULT WINAPI CComObjectRootBase::_Creator(void* pv, REFIID iid, void** ppvObject, DWORD dw)
  813. {
  814.     _ATL_CREATORDATA* pcd = (_ATL_CREATORDATA*)dw;
  815.     return pcd->pFunc(pv, iid, ppvObject);
  816. }
  817.  
  818. HRESULT WINAPI CComObjectRootBase::_Delegate(void* pv, REFIID iid, void** ppvObject, DWORD dw)
  819. {
  820.     HRESULT hRes = E_NOINTERFACE;
  821.     IUnknown* p = *(IUnknown**)((DWORD)pv + dw);
  822.     if (p != NULL)
  823.         hRes = p->QueryInterface(iid, ppvObject);
  824.     return hRes;
  825. }
  826.  
  827. HRESULT WINAPI CComObjectRootBase::_Chain(void* pv, REFIID iid, void** ppvObject, DWORD dw)
  828. {
  829.     _ATL_CHAINDATA* pcd = (_ATL_CHAINDATA*)dw;
  830.     void* p = (void*)((DWORD)pv + pcd->dwOffset);
  831.     return InternalQueryInterface(p, pcd->pFunc(), iid, ppvObject);
  832. }
  833.  
  834. HRESULT WINAPI CComObjectRootBase::_Cache(void* pv, REFIID iid, void** ppvObject, DWORD dw)
  835. {
  836.     HRESULT hRes = E_NOINTERFACE;
  837.     _ATL_CACHEDATA* pcd = (_ATL_CACHEDATA*)dw;
  838.     IUnknown** pp = (IUnknown**)((DWORD)pv + pcd->dwOffsetVar);
  839.     if (*pp == NULL)
  840.         hRes = pcd->pFunc(pv, IID_IUnknown, (void**)pp);
  841.     if (*pp != NULL)
  842.         hRes = (*pp)->QueryInterface(iid, ppvObject);
  843.     return hRes;
  844. }
  845.  
  846. /////////////////////////////////////////////////////////////////////////////
  847. // CComClassFactory
  848.  
  849. STDMETHODIMP CComClassFactory::CreateInstance(LPUNKNOWN pUnkOuter,
  850.     REFIID riid, void** ppvObj)
  851. {
  852.     _ASSERTE(m_pfnCreateInstance != NULL);
  853.     HRESULT hRes = E_POINTER;
  854.     if (ppvObj != NULL)
  855.     {
  856.         *ppvObj = NULL;
  857.         // can't ask for anything other than IUnknown when aggregating
  858.         _ASSERTE((pUnkOuter == NULL) || InlineIsEqualUnknown(riid));
  859.         if ((pUnkOuter != NULL) && !InlineIsEqualUnknown(riid))
  860.             hRes = CLASS_E_NOAGGREGATION;
  861.         else
  862.             hRes = m_pfnCreateInstance(pUnkOuter, riid, ppvObj);
  863.     }
  864.     return hRes;
  865. }
  866.  
  867. STDMETHODIMP CComClassFactory::LockServer(BOOL fLock)
  868. {
  869.     if (fLock)
  870.         _Module.Lock();
  871.     else
  872.         _Module.Unlock();
  873.     return S_OK;
  874. }
  875.  
  876. STDMETHODIMP CComClassFactory2Base::LockServer(BOOL fLock)
  877. {
  878.     if (fLock)
  879.         _Module.Lock();
  880.     else
  881.         _Module.Unlock();
  882.     return S_OK;
  883. }
  884.  
  885. #ifndef _ATL_NO_CONNECTION_POINTS
  886. /////////////////////////////////////////////////////////////////////////////
  887. // Connection Points
  888.  
  889. DWORD CComDynamicUnkArray::Add(IUnknown* pUnk)
  890. {
  891.     IUnknown** pp = NULL;
  892.     if (m_nSize == 0) // no connections
  893.     {
  894.         m_pUnk = pUnk;
  895.         m_nSize = 1;
  896.         return (DWORD)m_pUnk;
  897.     }
  898.     else if (m_nSize == 1)
  899.     {
  900.         //create array
  901.         pp = (IUnknown**)malloc(sizeof(IUnknown*)*_DEFAULT_VECTORLENGTH);
  902.         if (pp == NULL)
  903.             return 0;
  904.         memset(pp, 0, sizeof(IUnknown*)*_DEFAULT_VECTORLENGTH);
  905.         *pp = m_pUnk;
  906.         m_ppUnk = pp;
  907.         m_nSize = _DEFAULT_VECTORLENGTH;
  908.     }
  909.     for (pp = begin();pp<end();pp++)
  910.     {
  911.         if (*pp == NULL)
  912.         {
  913.             *pp = pUnk;
  914.             return (DWORD)pUnk;
  915.         }
  916.     }
  917.     int nAlloc = m_nSize*2;
  918.     pp = (IUnknown**)realloc(m_ppUnk, sizeof(IUnknown*)*nAlloc);
  919.     if (pp == NULL)
  920.         return 0;
  921.     m_ppUnk = pp;
  922.     memset(&m_ppUnk[m_nSize], 0, sizeof(IUnknown*)*m_nSize);
  923.     m_ppUnk[m_nSize] = pUnk;
  924.     m_nSize = nAlloc;
  925.     return (DWORD)pUnk;
  926. }
  927.  
  928. BOOL CComDynamicUnkArray::Remove(DWORD dwCookie)
  929. {
  930.     IUnknown** pp;
  931.     if (dwCookie == NULL)
  932.         return FALSE;
  933.     if (m_nSize == 0)
  934.         return FALSE;
  935.     if (m_nSize == 1)
  936.     {
  937.         if ((DWORD)m_pUnk == dwCookie)
  938.         {
  939.             m_nSize = 0;
  940.             return TRUE;
  941.         }
  942.         return FALSE;
  943.     }
  944.     for (pp=begin();pp<end();pp++)
  945.     {
  946.         if ((DWORD)*pp == dwCookie)
  947.         {
  948.             *pp = NULL;
  949.             return TRUE;
  950.         }
  951.     }
  952.     return FALSE;
  953. }
  954.  
  955. #endif //!_ATL_NO_CONNECTION_POINTS
  956.  
  957. #endif //__ATLCOM_H__
  958.  
  959. /////////////////////////////////////////////////////////////////////////////
  960. // Object Registry Support
  961.  
  962. static HRESULT WINAPI AtlRegisterProgID(LPCTSTR lpszCLSID, LPCTSTR lpszProgID, LPCTSTR lpszUserDesc)
  963. {
  964.     CRegKey keyProgID;
  965.     LONG lRes = keyProgID.Create(HKEY_CLASSES_ROOT, lpszProgID);
  966.     if (lRes == ERROR_SUCCESS)
  967.     {
  968.         keyProgID.SetValue(lpszUserDesc);
  969.         keyProgID.SetKeyValue(_T("CLSID"), lpszCLSID);
  970.         return S_OK;
  971.     }
  972.     return HRESULT_FROM_WIN32(lRes);
  973. }
  974.  
  975. void CComModule::AddCreateWndData(_AtlCreateWndData* pData, void* pObject)
  976. {
  977.     pData->m_pThis = pObject;
  978.     pData->m_dwThreadID = ::GetCurrentThreadId();
  979.     ::EnterCriticalSection(&m_csWindowCreate);
  980.     pData->m_pNext = m_pCreateWndList;
  981.     m_pCreateWndList = pData;
  982.     ::LeaveCriticalSection(&m_csWindowCreate);
  983. }
  984.  
  985. void* CComModule::ExtractCreateWndData()
  986. {
  987.     ::EnterCriticalSection(&m_csWindowCreate);
  988.     _AtlCreateWndData* pEntry = m_pCreateWndList;
  989.     if(pEntry == NULL)
  990.     {
  991.         ::LeaveCriticalSection(&m_csWindowCreate);
  992.         return NULL;
  993.     }
  994.  
  995.     DWORD dwThreadID = ::GetCurrentThreadId();
  996.     _AtlCreateWndData* pPrev = NULL;
  997.     while(pEntry != NULL)
  998.     {
  999.         if(pEntry->m_dwThreadID == dwThreadID)
  1000.         {
  1001.             if(pPrev == NULL)
  1002.                 m_pCreateWndList = pEntry->m_pNext;
  1003.             else
  1004.                 pPrev->m_pNext = pEntry->m_pNext;
  1005.             ::LeaveCriticalSection(&m_csWindowCreate);
  1006.             return pEntry->m_pThis;
  1007.         }
  1008.         pPrev = pEntry;
  1009.         pEntry = pEntry->m_pNext;
  1010.     }
  1011.  
  1012.     ::LeaveCriticalSection(&m_csWindowCreate);
  1013.     return NULL;
  1014. }
  1015.  
  1016. #ifdef _ATL_STATIC_REGISTRY
  1017. // Statically linking to Registry Ponent
  1018. HRESULT WINAPI CComModule::UpdateRegistryFromResourceS(UINT nResID, BOOL bRegister,
  1019.     struct _ATL_REGMAP_ENTRY* pMapEntries)
  1020. {
  1021.     USES_CONVERSION;
  1022.     CRegObject ro;
  1023.     TCHAR szModule[_MAX_PATH];
  1024.     GetModuleFileName(_Module.GetModuleInstance(), szModule, _MAX_PATH);
  1025.     LPOLESTR pszModule = T2OLE(szModule);
  1026.     ro.AddReplacement(OLESTR("Module"), pszModule);
  1027.     if (NULL != pMapEntries)
  1028.     {
  1029.         while (NULL != pMapEntries->szKey)
  1030.         {
  1031.             _ASSERTE(NULL != pMapEntries->szData);
  1032.             ro.AddReplacement(pMapEntries->szKey, pMapEntries->szData);
  1033.             pMapEntries++;
  1034.         }
  1035.     }
  1036.  
  1037.     LPCOLESTR szType = OLESTR("REGISTRY");
  1038.     return (bRegister) ? ro.ResourceRegister(pszModule, nResID, szType) :
  1039.             ro.ResourceUnregister(pszModule, nResID, szType);
  1040. }
  1041.  
  1042. HRESULT WINAPI CComModule::UpdateRegistryFromResourceS(LPCTSTR lpszRes, BOOL bRegister,
  1043.     struct _ATL_REGMAP_ENTRY* pMapEntries)
  1044. {
  1045.     USES_CONVERSION;
  1046.     CRegObject ro;
  1047.     TCHAR szModule[_MAX_PATH];
  1048.     GetModuleFileName(_Module.GetModuleInstance(), szModule, _MAX_PATH);
  1049.     LPOLESTR pszModule = T2OLE(szModule);
  1050.     ro.AddReplacement(OLESTR("Module"), pszModule);
  1051.     if (NULL != pMapEntries)
  1052.     {
  1053.         while (NULL != pMapEntries->szKey)
  1054.         {
  1055.             _ASSERTE(NULL != pMapEntries->szData);
  1056.             ro.AddReplacement(pMapEntries->szKey, pMapEntries->szData);
  1057.             pMapEntries++;
  1058.         }
  1059.     }
  1060.  
  1061.     LPCOLESTR szType = OLESTR("REGISTRY");
  1062.     LPCOLESTR pszRes = T2COLE(lpszRes);
  1063.     return (bRegister) ? ro.ResourceRegisterSz(pszModule, pszRes, szType) :
  1064.             ro.ResourceUnregisterSz(pszModule, pszRes, szType);
  1065. }
  1066. #endif // _ATL_STATIC_REGISTRY
  1067.  
  1068. HRESULT WINAPI CComModule::UpdateRegistryClass(const CLSID& clsid, LPCTSTR lpszProgID,
  1069.     LPCTSTR lpszVerIndProgID, UINT nDescID, DWORD dwFlags, BOOL bRegister)
  1070. {
  1071.     if (bRegister)
  1072.     {
  1073.         return RegisterClassHelper(clsid, lpszProgID, lpszVerIndProgID, nDescID,
  1074.             dwFlags);
  1075.     }
  1076.     else
  1077.         return UnregisterClassHelper(clsid, lpszProgID, lpszVerIndProgID);
  1078. }
  1079.  
  1080. HRESULT WINAPI CComModule::RegisterClassHelper(const CLSID& clsid, LPCTSTR lpszProgID,
  1081.     LPCTSTR lpszVerIndProgID, UINT nDescID, DWORD dwFlags)
  1082. {
  1083.     static const TCHAR szProgID[] = _T("ProgID");
  1084.     static const TCHAR szVIProgID[] = _T("VersionIndependentProgID");
  1085.     static const TCHAR szLS32[] = _T("LocalServer32");
  1086.     static const TCHAR szIPS32[] = _T("InprocServer32");
  1087.     static const TCHAR szThreadingModel[] = _T("ThreadingModel");
  1088.     static const TCHAR szAUTPRX32[] = _T("AUTPRX32.DLL");
  1089.     static const TCHAR szApartment[] = _T("Apartment");
  1090.     static const TCHAR szBoth[] = _T("both");
  1091.     USES_CONVERSION;
  1092.     HRESULT hRes = S_OK;
  1093.     TCHAR szDesc[256];
  1094.     LoadString(m_hInst, nDescID, szDesc, 256);
  1095.     TCHAR szModule[_MAX_PATH];
  1096.     GetModuleFileName(m_hInst, szModule, _MAX_PATH);
  1097.  
  1098.     LPOLESTR lpOleStr;
  1099.     StringFromCLSID(clsid, &lpOleStr);
  1100.     LPTSTR lpsz = OLE2T(lpOleStr);
  1101.  
  1102.     hRes = AtlRegisterProgID(lpsz, lpszProgID, szDesc);
  1103.     if (hRes == S_OK)
  1104.         hRes = AtlRegisterProgID(lpsz, lpszVerIndProgID, szDesc);
  1105.     LONG lRes = ERROR_SUCCESS;
  1106.     if (hRes == S_OK)
  1107.     {
  1108.         CRegKey key;
  1109.         LONG lRes = key.Open(HKEY_CLASSES_ROOT, _T("CLSID"));
  1110.         if (lRes == ERROR_SUCCESS)
  1111.         {
  1112.             lRes = key.Create(key, lpsz);
  1113.             if (lRes == ERROR_SUCCESS)
  1114.             {
  1115.                 key.SetValue(szDesc);
  1116.                 key.SetKeyValue(szProgID, lpszProgID);
  1117.                 key.SetKeyValue(szVIProgID, lpszVerIndProgID);
  1118.  
  1119.                 if ((m_hInst == NULL) || (m_hInst == GetModuleHandle(NULL))) // register as EXE
  1120.                     key.SetKeyValue(szLS32, szModule);
  1121.                 else
  1122.                 {
  1123.                     key.SetKeyValue(szIPS32, (dwFlags & AUTPRXFLAG) ? szAUTPRX32 : szModule);
  1124.                     LPCTSTR lpszModel = (dwFlags & THREADFLAGS_BOTH) ? szBoth :
  1125.                         (dwFlags & THREADFLAGS_APARTMENT) ? szApartment : NULL;
  1126.                     if (lpszModel != NULL)
  1127.                         key.SetKeyValue(szIPS32, lpszModel, szThreadingModel);
  1128.                 }
  1129.             }
  1130.         }
  1131.     }
  1132.     CoTaskMemFree(lpOleStr);
  1133.     if (lRes != ERROR_SUCCESS)
  1134.         hRes = HRESULT_FROM_WIN32(lRes);
  1135.     return hRes;
  1136. }
  1137.  
  1138. HRESULT WINAPI CComModule::UnregisterClassHelper(const CLSID& clsid, LPCTSTR lpszProgID,
  1139.     LPCTSTR lpszVerIndProgID)
  1140. {
  1141.     USES_CONVERSION;
  1142.     CRegKey key;
  1143.  
  1144.     key.Attach(HKEY_CLASSES_ROOT);
  1145.     if (lpszProgID != NULL && lstrcmpi(lpszProgID, _T("")))
  1146.         key.RecurseDeleteKey(lpszProgID);
  1147.     if (lpszVerIndProgID != NULL && lstrcmpi(lpszVerIndProgID, _T("")))
  1148.         key.RecurseDeleteKey(lpszVerIndProgID);
  1149.     LPOLESTR lpOleStr;
  1150.     StringFromCLSID(clsid, &lpOleStr);
  1151.     LPTSTR lpsz = OLE2T(lpOleStr);
  1152.     if (key.Open(key, _T("CLSID")) == ERROR_SUCCESS)
  1153.         key.RecurseDeleteKey(lpsz);
  1154.     CoTaskMemFree(lpOleStr);
  1155.     return S_OK;
  1156. }
  1157.  
  1158.  
  1159. /////////////////////////////////////////////////////////////////////////////
  1160. // CRegKey
  1161.  
  1162. LONG CRegKey::Close()
  1163. {
  1164.     LONG lRes = ERROR_SUCCESS;
  1165.     if (m_hKey != NULL)
  1166.     {
  1167.         lRes = RegCloseKey(m_hKey);
  1168.         m_hKey = NULL;
  1169.     }
  1170.     return lRes;
  1171. }
  1172.  
  1173. LONG CRegKey::Create(HKEY hKeyParent, LPCTSTR lpszKeyName,
  1174.     LPTSTR lpszClass, DWORD dwOptions, REGSAM samDesired,
  1175.     LPSECURITY_ATTRIBUTES lpSecAttr, LPDWORD lpdwDisposition)
  1176. {
  1177.     _ASSERTE(hKeyParent != NULL);
  1178.     DWORD dw;
  1179.     HKEY hKey = NULL;
  1180.     LONG lRes = RegCreateKeyEx(hKeyParent, lpszKeyName, 0,
  1181.         lpszClass, dwOptions, samDesired, lpSecAttr, &hKey, &dw);
  1182.     if (lpdwDisposition != NULL)
  1183.         *lpdwDisposition = dw;
  1184.     if (lRes == ERROR_SUCCESS)
  1185.     {
  1186.         lRes = Close();
  1187.         m_hKey = hKey;
  1188.     }
  1189.     return lRes;
  1190. }
  1191.  
  1192. LONG CRegKey::Open(HKEY hKeyParent, LPCTSTR lpszKeyName, REGSAM samDesired)
  1193. {
  1194.     _ASSERTE(hKeyParent != NULL);
  1195.     HKEY hKey = NULL;
  1196.     LONG lRes = RegOpenKeyEx(hKeyParent, lpszKeyName, 0, samDesired, &hKey);
  1197.     if (lRes == ERROR_SUCCESS)
  1198.     {
  1199.         lRes = Close();
  1200.         _ASSERTE(lRes == ERROR_SUCCESS);
  1201.         m_hKey = hKey;
  1202.     }
  1203.     return lRes;
  1204. }
  1205.  
  1206. LONG CRegKey::QueryValue(DWORD& dwValue, LPCTSTR lpszValueName)
  1207. {
  1208.     DWORD dwType = NULL;
  1209.     DWORD dwCount = sizeof(DWORD);
  1210.     LONG lRes = RegQueryValueEx(m_hKey, (LPTSTR)lpszValueName, NULL, &dwType,
  1211.         (LPBYTE)&dwValue, &dwCount);
  1212.     _ASSERTE((lRes!=ERROR_SUCCESS) || (dwType == REG_DWORD));
  1213.     _ASSERTE((lRes!=ERROR_SUCCESS) || (dwCount == sizeof(DWORD)));
  1214.     return lRes;
  1215. }
  1216.  
  1217. LONG CRegKey::QueryValue(LPTSTR szValue, LPCTSTR lpszValueName, DWORD* pdwCount)
  1218. {
  1219.     _ASSERTE(pdwCount != NULL);
  1220.     DWORD dwType = NULL;
  1221.     LONG lRes = RegQueryValueEx(m_hKey, (LPTSTR)lpszValueName, NULL, &dwType,
  1222.         (LPBYTE)szValue, pdwCount);
  1223.     _ASSERTE((lRes!=ERROR_SUCCESS) || (dwType == REG_SZ) ||
  1224.              (dwType == REG_MULTI_SZ) || (dwType == REG_EXPAND_SZ));
  1225.     return lRes;
  1226. }
  1227.  
  1228. LONG WINAPI CRegKey::SetValue(HKEY hKeyParent, LPCTSTR lpszKeyName, LPCTSTR lpszValue, LPCTSTR lpszValueName)
  1229. {
  1230.     _ASSERTE(lpszValue != NULL);
  1231.     CRegKey key;
  1232.     LONG lRes = key.Create(hKeyParent, lpszKeyName);
  1233.     if (lRes == ERROR_SUCCESS)
  1234.         lRes = key.SetValue(lpszValue, lpszValueName);
  1235.     return lRes;
  1236. }
  1237.  
  1238. LONG CRegKey::SetKeyValue(LPCTSTR lpszKeyName, LPCTSTR lpszValue, LPCTSTR lpszValueName)
  1239. {
  1240.     _ASSERTE(lpszValue != NULL);
  1241.     CRegKey key;
  1242.     LONG lRes = key.Create(m_hKey, lpszKeyName);
  1243.     if (lRes == ERROR_SUCCESS)
  1244.         lRes = key.SetValue(lpszValue, lpszValueName);
  1245.     return lRes;
  1246. }
  1247.  
  1248. LONG CRegKey::SetValue(DWORD dwValue, LPCTSTR lpszValueName)
  1249. {
  1250.     _ASSERTE(m_hKey != NULL);
  1251.     return RegSetValueEx(m_hKey, lpszValueName, NULL, REG_DWORD,
  1252.         (BYTE * const)&dwValue, sizeof(DWORD));
  1253. }
  1254.  
  1255. HRESULT CRegKey::SetValue(LPCTSTR lpszValue, LPCTSTR lpszValueName)
  1256. {
  1257.     _ASSERTE(lpszValue != NULL);
  1258.     _ASSERTE(m_hKey != NULL);
  1259.     return RegSetValueEx(m_hKey, lpszValueName, NULL, REG_SZ,
  1260.         (BYTE * const)lpszValue, (lstrlen(lpszValue)+1)*sizeof(TCHAR));
  1261. }
  1262.  
  1263. //RecurseDeleteKey is necessary because on NT RegDeleteKey doesn't work if the
  1264. //specified key has subkeys
  1265. LONG CRegKey::RecurseDeleteKey(LPCTSTR lpszKey)
  1266. {
  1267.     CRegKey key;
  1268.     LONG lRes = key.Open(m_hKey, lpszKey);
  1269.     if (lRes != ERROR_SUCCESS)
  1270.         return lRes;
  1271.     FILETIME time;
  1272.     TCHAR szBuffer[256];
  1273.     DWORD dwSize = 256;
  1274.     while (RegEnumKeyEx(key.m_hKey, 0, szBuffer, &dwSize, NULL, NULL, NULL,
  1275.         &time)==ERROR_SUCCESS)
  1276.     {
  1277.         lRes = key.RecurseDeleteKey(szBuffer);
  1278.         if (lRes != ERROR_SUCCESS)
  1279.             return lRes;
  1280.         dwSize = 256;
  1281.     }
  1282.     key.Close();
  1283.     return DeleteSubKey(lpszKey);
  1284. }
  1285.  
  1286. #ifdef __ATLCOM_H__
  1287. #ifndef _ATL_NO_SECURITY
  1288.  
  1289. CSecurityDescriptor::CSecurityDescriptor()
  1290. {
  1291.     m_pSD = NULL;
  1292.     m_pOwner = NULL;
  1293.     m_pGroup = NULL;
  1294.     m_pDACL = NULL;
  1295.     m_pSACL= NULL;
  1296. }
  1297.  
  1298. CSecurityDescriptor::~CSecurityDescriptor()
  1299. {
  1300.     if (m_pSD)
  1301.         delete m_pSD;
  1302.     if (m_pOwner)
  1303.         free(m_pOwner);
  1304.     if (m_pGroup)
  1305.         free(m_pGroup);
  1306.     if (m_pDACL)
  1307.         free(m_pDACL);
  1308.     if (m_pSACL)
  1309.         free(m_pSACL);
  1310. }
  1311.  
  1312. HRESULT CSecurityDescriptor::Initialize()
  1313. {
  1314.     if (m_pSD)
  1315.     {
  1316.         delete m_pSD;
  1317.         m_pSD = NULL;
  1318.     }
  1319.     if (m_pOwner)
  1320.     {
  1321.         free(m_pOwner);
  1322.         m_pOwner = NULL;
  1323.     }
  1324.     if (m_pGroup)
  1325.     {
  1326.         free(m_pGroup);
  1327.         m_pGroup = NULL;
  1328.     }
  1329.     if (m_pDACL)
  1330.     {
  1331.         free(m_pDACL);
  1332.         m_pDACL = NULL;
  1333.     }
  1334.     if (m_pSACL)
  1335.     {
  1336.         free(m_pSACL);
  1337.         m_pSACL = NULL;
  1338.     }
  1339.  
  1340.     ATLTRY(m_pSD = new SECURITY_DESCRIPTOR);
  1341.     if (!m_pSD)
  1342.         return E_OUTOFMEMORY;
  1343.     if (!InitializeSecurityDescriptor(m_pSD, SECURITY_DESCRIPTOR_REVISION))
  1344.     {
  1345.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1346.         delete m_pSD;
  1347.         m_pSD = NULL;
  1348.         _ASSERTE(FALSE);
  1349.         return hr;
  1350.     }
  1351.     // Set the DACL to allow EVERYONE
  1352.     SetSecurityDescriptorDacl(m_pSD, TRUE, NULL, FALSE);
  1353.     return S_OK;
  1354. }
  1355.  
  1356. HRESULT CSecurityDescriptor::InitializeFromProcessToken(BOOL bDefaulted)
  1357. {
  1358.     PSID pUserSid;
  1359.     PSID pGroupSid;
  1360.     HRESULT hr;
  1361.  
  1362.     Initialize();
  1363.     hr = GetProcessSids(&pUserSid, &pGroupSid);
  1364.     if (FAILED(hr))
  1365.         return hr;
  1366.     hr = SetOwner(pUserSid, bDefaulted);
  1367.     if (FAILED(hr))
  1368.         return hr;
  1369.     hr = SetGroup(pGroupSid, bDefaulted);
  1370.     if (FAILED(hr))
  1371.         return hr;
  1372.     return S_OK;
  1373. }
  1374.  
  1375. HRESULT CSecurityDescriptor::InitializeFromThreadToken(BOOL bDefaulted, BOOL bRevertToProcessToken)
  1376. {
  1377.     PSID pUserSid;
  1378.     PSID pGroupSid;
  1379.     HRESULT hr;
  1380.  
  1381.     Initialize();
  1382.     hr = GetThreadSids(&pUserSid, &pGroupSid);
  1383.     if (HRESULT_CODE(hr) == ERROR_NO_TOKEN && bRevertToProcessToken)
  1384.         hr = GetProcessSids(&pUserSid, &pGroupSid);
  1385.     if (FAILED(hr))
  1386.         return hr;
  1387.     hr = SetOwner(pUserSid, bDefaulted);
  1388.     if (FAILED(hr))
  1389.         return hr;
  1390.     hr = SetGroup(pGroupSid, bDefaulted);
  1391.     if (FAILED(hr))
  1392.         return hr;
  1393.     return S_OK;
  1394. }
  1395.  
  1396. HRESULT CSecurityDescriptor::SetOwner(PSID pOwnerSid, BOOL bDefaulted)
  1397. {
  1398.     _ASSERTE(m_pSD);
  1399.  
  1400.     // Mark the SD as having no owner
  1401.     if (!SetSecurityDescriptorOwner(m_pSD, NULL, bDefaulted))
  1402.     {
  1403.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1404.         _ASSERTE(FALSE);
  1405.         return hr;
  1406.     }
  1407.  
  1408.     if (m_pOwner)
  1409.     {
  1410.         free(m_pOwner);
  1411.         m_pOwner = NULL;
  1412.     }
  1413.  
  1414.     // If they asked for no owner don't do the copy
  1415.     if (pOwnerSid == NULL)
  1416.         return S_OK;
  1417.  
  1418.     // Make a copy of the Sid for the return value
  1419.     DWORD dwSize = GetLengthSid(pOwnerSid);
  1420.  
  1421.     m_pOwner = (PSID) malloc(dwSize);
  1422.     if (!m_pOwner)
  1423.     {
  1424.         // Insufficient memory to allocate Sid
  1425.         _ASSERTE(FALSE);
  1426.         return E_OUTOFMEMORY;
  1427.     }
  1428.     if (!CopySid(dwSize, m_pOwner, pOwnerSid))
  1429.     {
  1430.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1431.         _ASSERTE(FALSE);
  1432.         free(m_pOwner);
  1433.         m_pOwner = NULL;
  1434.         return hr;
  1435.     }
  1436.  
  1437.     _ASSERTE(IsValidSid(m_pOwner));
  1438.  
  1439.     if (!SetSecurityDescriptorOwner(m_pSD, m_pOwner, bDefaulted))
  1440.     {
  1441.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1442.         _ASSERTE(FALSE);
  1443.         free(m_pOwner);
  1444.         m_pOwner = NULL;
  1445.         return hr;
  1446.     }
  1447.  
  1448.     return S_OK;
  1449. }
  1450.  
  1451. HRESULT CSecurityDescriptor::SetGroup(PSID pGroupSid, BOOL bDefaulted)
  1452. {
  1453.     _ASSERTE(m_pSD);
  1454.  
  1455.     // Mark the SD as having no Group
  1456.     if (!SetSecurityDescriptorGroup(m_pSD, NULL, bDefaulted))
  1457.     {
  1458.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1459.         _ASSERTE(FALSE);
  1460.         return hr;
  1461.     }
  1462.  
  1463.     if (m_pGroup)
  1464.     {
  1465.         free(m_pGroup);
  1466.         m_pGroup = NULL;
  1467.     }
  1468.  
  1469.     // If they asked for no Group don't do the copy
  1470.     if (pGroupSid == NULL)
  1471.         return S_OK;
  1472.  
  1473.     // Make a copy of the Sid for the return value
  1474.     DWORD dwSize = GetLengthSid(pGroupSid);
  1475.  
  1476.     m_pGroup = (PSID) malloc(dwSize);
  1477.     if (!m_pGroup)
  1478.     {
  1479.         // Insufficient memory to allocate Sid
  1480.         _ASSERTE(FALSE);
  1481.         return E_OUTOFMEMORY;
  1482.     }
  1483.     if (!CopySid(dwSize, m_pGroup, pGroupSid))
  1484.     {
  1485.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1486.         _ASSERTE(FALSE);
  1487.         free(m_pGroup);
  1488.         m_pGroup = NULL;
  1489.         return hr;
  1490.     }
  1491.  
  1492.     _ASSERTE(IsValidSid(m_pGroup));
  1493.  
  1494.     if (!SetSecurityDescriptorGroup(m_pSD, m_pGroup, bDefaulted))
  1495.     {
  1496.         HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
  1497.         _ASSERTE(FALSE);
  1498.         free(m_pGroup);
  1499.         m_pGroup = NULL;
  1500.         return hr;
  1501.     }
  1502.  
  1503.     return S_OK;
  1504. }
  1505.  
  1506. HRESULT CSecurityDescriptor::Allow(LPCTSTR pszPrincipal, DWORD dwAccessMask)
  1507. {
  1508.     HRESULT hr = AddAccessAllowedACEToACL(&m_pDACL, pszPrincipal, dwAccessMask);
  1509.     if (SUCCEEDED(hr))
  1510.         SetSecurityDescriptorDacl(m_pSD, TRUE, m_pDACL, FALSE);
  1511.     return hr;
  1512. }
  1513.  
  1514. HRESULT CSecurityDescriptor::Deny(LPCTSTR pszPrincipal, DWORD dwAccessMask)
  1515. {
  1516.     HRESULT hr = AddAccessDeniedACEToACL(&m_pDACL, pszPrincipal, dwAccessMask);
  1517.     if (SUCCEEDED(hr))
  1518.         SetSecurityDescriptorDacl(m_pSD, TRUE, m_pDACL, FALSE);
  1519.     return hr;
  1520. }
  1521.  
  1522. HRESULT CSecurityDescriptor::Revoke(LPCTSTR pszPrincipal)
  1523. {
  1524.     HRESULT hr = RemovePrincipalFromACL(m_pDACL, pszPrincipal);
  1525.     if (SUCCEEDED(hr))
  1526.         SetSecurityDescriptorDacl(m_pSD, TRUE, m_pDACL, FALSE);
  1527.     return hr;
  1528. }
  1529.  
  1530. HRESULT CSecurityDescriptor::GetProcessSids(PSID* ppUserSid, PSID* ppGroupSid)
  1531. {
  1532.     BOOL bRes;
  1533.     HRESULT hr;
  1534.     HANDLE hToken = NULL;
  1535.     if (ppUserSid)
  1536.         *ppUserSid = NULL;
  1537.     if (ppGroupSid)
  1538.         *ppGroupSid = NULL;
  1539.     bRes = OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &hToken);
  1540.     if (!bRes)
  1541.     {
  1542.         // Couldn't open process token
  1543.         hr = HRESULT_FROM_WIN32(GetLastError());
  1544.         _ASSERTE(FALSE);
  1545.         return hr;
  1546.     }
  1547.     hr = GetTokenSids(hToken, ppUserSid, ppGroupSid);
  1548.     return hr;
  1549. }
  1550.  
  1551. HRESULT CSecurityDescriptor::GetThreadSids(PSID* ppUserSid, PSID* ppGroupSid, BOOL bOpenAsSelf)
  1552. {
  1553.     BOOL bRes;
  1554.     HRESULT hr;
  1555.     HANDLE hToken = NULL;
  1556.     if (ppUserSid)
  1557.         *ppUserSid = NULL;
  1558.     if (ppGroupSid)
  1559.         *ppGroupSid = NULL;
  1560.     bRes = OpenThreadToken(GetCurrentThread(), TOKEN_QUERY, bOpenAsSelf, &hToken);
  1561.     if (!bRes)
  1562.     {
  1563.         // Couldn't open thread token
  1564.         hr = HRESULT_FROM_WIN32(GetLastError());
  1565.         return hr;
  1566.     }
  1567.     hr = GetTokenSids(hToken, ppUserSid, ppGroupSid);
  1568.     return hr;
  1569. }
  1570.  
  1571.  
  1572. HRESULT CSecurityDescriptor::GetTokenSids(HANDLE hToken, PSID* ppUserSid, PSID* ppGroupSid)
  1573. {
  1574.     DWORD dwSize;
  1575.     HRESULT hr;
  1576.     PTOKEN_USER ptkUser = NULL;
  1577.     PTOKEN_PRIMARY_GROUP ptkGroup = NULL;
  1578.  
  1579.     if (ppUserSid)
  1580.         *ppUserSid = NULL;
  1581.     if (ppGroupSid)
  1582.         *ppGroupSid = NULL;
  1583.  
  1584.     if (ppUserSid)
  1585.     {
  1586.         // Get length required for TokenUser by specifying buffer length of 0
  1587.         GetTokenInformation(hToken, TokenUser, NULL, 0, &dwSize);
  1588.         hr = GetLastError();
  1589.         if (hr != ERROR_INSUFFICIENT_BUFFER)
  1590.         {
  1591.             // Expected ERROR_INSUFFICIENT_BUFFER
  1592.             _ASSERTE(FALSE);
  1593.             hr = HRESULT_FROM_WIN32(hr);
  1594.             goto failed;
  1595.         }
  1596.  
  1597.         ptkUser = (TOKEN_USER*) malloc(dwSize);
  1598.         if (!ptkUser)
  1599.         {
  1600.             // Insufficient memory to allocate TOKEN_USER
  1601.             _ASSERTE(FALSE);
  1602.             hr = E_OUTOFMEMORY;
  1603.             goto failed;
  1604.         }
  1605.         // Get Sid of process token.
  1606.         if (!GetTokenInformation(hToken, TokenUser, ptkUser, dwSize, &dwSize))
  1607.         {
  1608.             // Couldn't get user info
  1609.             hr = HRESULT_FROM_WIN32(GetLastError());
  1610.             _ASSERTE(FALSE);
  1611.             goto failed;
  1612.         }
  1613.  
  1614.         // Make a copy of the Sid for the return value
  1615.         dwSize = GetLengthSid(ptkUser->User.Sid);
  1616.  
  1617.         PSID pSid = (PSID) malloc(dwSize);
  1618.         if (!pSid)
  1619.         {
  1620.             // Insufficient memory to allocate Sid
  1621.             _ASSERTE(FALSE);
  1622.             hr = E_OUTOFMEMORY;
  1623.             goto failed;
  1624.         }
  1625.         if (!CopySid(dwSize, pSid, ptkUser->User.Sid))
  1626.         {
  1627.             hr = HRESULT_FROM_WIN32(GetLastError());
  1628.             _ASSERTE(FALSE);
  1629.             goto failed;
  1630.         }
  1631.  
  1632.         _ASSERTE(IsValidSid(pSid));
  1633.         *ppUserSid = pSid;
  1634.         free(ptkUser);
  1635.     }
  1636.     if (ppGroupSid)
  1637.     {
  1638.         // Get length required for TokenPrimaryGroup by specifying buffer length of 0
  1639.         GetTokenInformation(hToken, TokenPrimaryGroup, NULL, 0, &dwSize);
  1640.         hr = GetLastError();
  1641.         if (hr != ERROR_INSUFFICIENT_BUFFER)
  1642.         {
  1643.             // Expected ERROR_INSUFFICIENT_BUFFER
  1644.             _ASSERTE(FALSE);
  1645.             hr = HRESULT_FROM_WIN32(hr);
  1646.             goto failed;
  1647.         }
  1648.  
  1649.         ptkGroup = (TOKEN_PRIMARY_GROUP*) malloc(dwSize);
  1650.         if (!ptkGroup)
  1651.         {
  1652.             // Insufficient memory to allocate TOKEN_USER
  1653.             _ASSERTE(FALSE);
  1654.             hr = E_OUTOFMEMORY;
  1655.             goto failed;
  1656.         }
  1657.         // Get Sid of process token.
  1658.         if (!GetTokenInformation(hToken, TokenPrimaryGroup, ptkGroup, dwSize, &dwSize))
  1659.         {
  1660.             // Couldn't get user info
  1661.             hr = HRESULT_FROM_WIN32(GetLastError());
  1662.             _ASSERTE(FALSE);
  1663.             goto failed;
  1664.         }
  1665.  
  1666.         // Make a copy of the Sid for the return value
  1667.         dwSize = GetLengthSid(ptkGroup->PrimaryGroup);
  1668.  
  1669.         PSID pSid = (PSID) malloc(dwSize);
  1670.         if (!pSid)
  1671.         {
  1672.             // Insufficient memory to allocate Sid
  1673.             _ASSERTE(FALSE);
  1674.             hr = E_OUTOFMEMORY;
  1675.             goto failed;
  1676.         }
  1677.         if (!CopySid(dwSize, pSid, ptkGroup->PrimaryGroup))
  1678.         {
  1679.             hr = HRESULT_FROM_WIN32(GetLastError());
  1680.             _ASSERTE(FALSE);
  1681.             goto failed;
  1682.         }
  1683.  
  1684.         _ASSERTE(IsValidSid(pSid));
  1685.  
  1686.         *ppGroupSid = pSid;
  1687.         free(ptkGroup);
  1688.     }
  1689.  
  1690.     return S_OK;
  1691.  
  1692. failed:
  1693.     if (ptkUser)
  1694.         free(ptkUser);
  1695.     if (ptkGroup)
  1696.         free (ptkGroup);
  1697.     return hr;
  1698. }
  1699.  
  1700.  
  1701. HRESULT CSecurityDescriptor::GetCurrentUserSID(PSID *ppSid)
  1702. {
  1703.     HANDLE tkHandle;
  1704.  
  1705.     if (OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &tkHandle))
  1706.     {
  1707.         TOKEN_USER *tkUser;
  1708.         DWORD tkSize;
  1709.         DWORD sidLength;
  1710.  
  1711.         // Call to get size information for alloc
  1712.         GetTokenInformation(tkHandle, TokenUser, NULL, 0, &tkSize);
  1713.         tkUser = (TOKEN_USER *) malloc(tkSize);
  1714.  
  1715.         // Now make the real call
  1716.         if (GetTokenInformation(tkHandle, TokenUser, tkUser, tkSize, &tkSize))
  1717.         {
  1718.             sidLength = GetLengthSid(tkUser->User.Sid);
  1719.             *ppSid = (PSID) malloc(sidLength);
  1720.  
  1721.             memcpy(*ppSid, tkUser->User.Sid, sidLength);
  1722.             CloseHandle(tkHandle);
  1723.  
  1724.             free(tkUser);
  1725.             return S_OK;
  1726.         }
  1727.         else
  1728.         {
  1729.             free(tkUser);
  1730.             return HRESULT_FROM_WIN32(GetLastError());
  1731.         }
  1732.     }
  1733.     return HRESULT_FROM_WIN32(GetLastError());
  1734. }
  1735.  
  1736.  
  1737. HRESULT CSecurityDescriptor::GetPrincipalSID(LPCTSTR pszPrincipal, PSID *ppSid)
  1738. {
  1739.     HRESULT hr;
  1740.     LPTSTR pszRefDomain = NULL;
  1741.     DWORD dwDomainSize = 0;
  1742.     DWORD dwSidSize = 0;
  1743.     SID_NAME_USE snu;
  1744.  
  1745.     // Call to get size info for alloc
  1746.     LookupAccountName(NULL, pszPrincipal, *ppSid, &dwSidSize, pszRefDomain, &dwDomainSize, &snu);
  1747.  
  1748.     hr = GetLastError();
  1749.     if (hr != ERROR_INSUFFICIENT_BUFFER)
  1750.         return HRESULT_FROM_WIN32(hr);
  1751.  
  1752.     ATLTRY(pszRefDomain = new TCHAR[dwDomainSize]);
  1753.     if (pszRefDomain == NULL)
  1754.         return E_OUTOFMEMORY;
  1755.  
  1756.     *ppSid = (PSID) malloc(dwSidSize);
  1757.     if (*ppSid != NULL)
  1758.     {
  1759.         if (!LookupAccountName(NULL, pszPrincipal, *ppSid, &dwSidSize, pszRefDomain, &dwDomainSize, &snu))
  1760.         {
  1761.             free(*ppSid);
  1762.             *ppSid = NULL;
  1763.             delete[] pszRefDomain;
  1764.             return HRESULT_FROM_WIN32(GetLastError());
  1765.         }
  1766.         delete[] pszRefDomain;
  1767.         return S_OK;
  1768.     }
  1769.     delete[] pszRefDomain;
  1770.     return E_OUTOFMEMORY;
  1771. }
  1772.  
  1773.  
  1774. HRESULT CSecurityDescriptor::Attach(PSECURITY_DESCRIPTOR pSelfRelativeSD)
  1775. {
  1776.     PACL    pDACL = NULL;
  1777.     PACL    pSACL = NULL;
  1778.     BOOL    bDACLPresent, bSACLPresent;
  1779.     BOOL    bDefaulted;
  1780.     PACL    m_pDACL = NULL;
  1781.     ACCESS_ALLOWED_ACE* pACE;
  1782.     HRESULT hr;
  1783.     PSID    pUserSid;
  1784.     PSID    pGroupSid;
  1785.  
  1786.     hr = Initialize();
  1787.     if(FAILED(hr))
  1788.         return hr;
  1789.  
  1790.     // get the existing DACL.
  1791.     if (!GetSecurityDescriptorDacl(pSelfRelativeSD, &bDACLPresent, &pDACL, &bDefaulted))
  1792.         goto failed;
  1793.  
  1794.     if (bDACLPresent)
  1795.     {
  1796.         if (pDACL)
  1797.         {
  1798.             // allocate new DACL.
  1799.             m_pDACL = (PACL) malloc(pDACL->AclSize);
  1800.             if (!m_pDACL)
  1801.                 goto failed;
  1802.  
  1803.             // initialize the DACL
  1804.             if (!InitializeAcl(m_pDACL, pDACL->AclSize, ACL_REVISION))
  1805.                 goto failed;
  1806.  
  1807.             // copy the ACES
  1808.             for (int i = 0; i < pDACL->AceCount; i++)
  1809.             {
  1810.                 if (!GetAce(pDACL, i, (void **)&pACE))
  1811.                     goto failed;
  1812.  
  1813.                 if (!AddAccessAllowedAce(m_pDACL, ACL_REVISION, pACE->Mask, (PSID)&(pACE->SidStart)))
  1814.                     goto failed;
  1815.             }
  1816.  
  1817.             if (!IsValidAcl(m_pDACL))
  1818.                 goto failed;
  1819.         }
  1820.  
  1821.         // set the DACL
  1822.         if (!SetSecurityDescriptorDacl(m_pSD, m_pDACL ? TRUE : FALSE, m_pDACL, bDefaulted))
  1823.             goto failed;
  1824.     }
  1825.  
  1826.     // get the existing SACL.
  1827.     if (!GetSecurityDescriptorSacl(pSelfRelativeSD, &bSACLPresent, &pSACL, &bDefaulted))
  1828.         goto failed;
  1829.  
  1830.     if (bSACLPresent)
  1831.     {
  1832.         if (pSACL)
  1833.         {
  1834.             // allocate new SACL.
  1835.             m_pSACL = (PACL) malloc(pSACL->AclSize);
  1836.             if (!m_pSACL)
  1837.                 goto failed;
  1838.  
  1839.             // initialize the SACL
  1840.             if (!InitializeAcl(m_pSACL, pSACL->AclSize, ACL_REVISION))
  1841.                 goto failed;
  1842.  
  1843.             // copy the ACES
  1844.             for (int i = 0; i < pSACL->AceCount; i++)
  1845.             {
  1846.                 if (!GetAce(pSACL, i, (void **)&pACE))
  1847.                     goto failed;
  1848.  
  1849.                 if (!AddAccessAllowedAce(m_pSACL, ACL_REVISION, pACE->Mask, (PSID)&(pACE->SidStart)))
  1850.                     goto failed;
  1851.             }
  1852.  
  1853.             if (!IsValidAcl(m_pSACL))
  1854.                 goto failed;
  1855.         }
  1856.  
  1857.         // set the SACL
  1858.         if (!SetSecurityDescriptorSacl(m_pSD, m_pSACL ? TRUE : FALSE, m_pSACL, bDefaulted))
  1859.             goto failed;
  1860.     }
  1861.  
  1862.     if (!GetSecurityDescriptorOwner(m_pSD, &pUserSid, &bDefaulted))
  1863.         goto failed;
  1864.  
  1865.     if (FAILED(SetOwner(pUserSid, bDefaulted)))
  1866.         goto failed;
  1867.  
  1868.     if (!GetSecurityDescriptorGroup(m_pSD, &pGroupSid, &bDefaulted))
  1869.         goto failed;
  1870.  
  1871.     if (FAILED(SetGroup(pGroupSid, bDefaulted)))
  1872.         goto failed;
  1873.  
  1874.     if (!IsValidSecurityDescriptor(m_pSD))
  1875.         goto failed;
  1876.  
  1877.     return hr;
  1878.  
  1879. failed:
  1880.     if (m_pDACL)
  1881.         free(m_pDACL);
  1882.     if (m_pSD)
  1883.         free(m_pSD);
  1884.     return E_UNEXPECTED;
  1885. }
  1886.  
  1887. HRESULT CSecurityDescriptor::AttachObject(HANDLE hObject)
  1888. {
  1889.     HRESULT hr;
  1890.     DWORD dwSize = 0;
  1891.     PSECURITY_DESCRIPTOR pSD = NULL;
  1892.  
  1893.     GetKernelObjectSecurity(hObject, OWNER_SECURITY_INFORMATION | GROUP_SECURITY_INFORMATION |
  1894.         DACL_SECURITY_INFORMATION, pSD, 0, &dwSize);
  1895.  
  1896.     hr = GetLastError();
  1897.     if (hr != ERROR_INSUFFICIENT_BUFFER)
  1898.         return HRESULT_FROM_WIN32(hr);
  1899.  
  1900.     pSD = (PSECURITY_DESCRIPTOR) malloc(dwSize);
  1901.  
  1902.     if (!GetKernelObjectSecurity(hObject, OWNER_SECURITY_INFORMATION | GROUP_SECURITY_INFORMATION |
  1903.         DACL_SECURITY_INFORMATION, pSD, dwSize, &dwSize))
  1904.     {
  1905.         hr = HRESULT_FROM_WIN32(GetLastError());
  1906.         free(pSD);
  1907.         return hr;
  1908.     }
  1909.  
  1910.     hr = Attach(pSD);
  1911.     free(pSD);
  1912.     return hr;
  1913. }
  1914.  
  1915.  
  1916. HRESULT CSecurityDescriptor::CopyACL(PACL pDest, PACL pSrc)
  1917. {
  1918.     ACL_SIZE_INFORMATION aclSizeInfo;
  1919.     LPVOID pAce;
  1920.     ACE_HEADER *aceHeader;
  1921.  
  1922.     if (pSrc == NULL)
  1923.         return S_OK;
  1924.  
  1925.     if (!GetAclInformation(pSrc, (LPVOID) &aclSizeInfo, sizeof(ACL_SIZE_INFORMATION), AclSizeInformation))
  1926.         return HRESULT_FROM_WIN32(GetLastError());
  1927.  
  1928.     // Copy all of the ACEs to the new ACL
  1929.     for (UINT i = 0; i < aclSizeInfo.AceCount; i++)
  1930.     {
  1931.         if (!GetAce(pSrc, i, &pAce))
  1932.             return HRESULT_FROM_WIN32(GetLastError());
  1933.  
  1934.         aceHeader = (ACE_HEADER *) pAce;
  1935.  
  1936.         if (!AddAce(pDest, ACL_REVISION, 0xffffffff, pAce, aceHeader->AceSize))
  1937.             return HRESULT_FROM_WIN32(GetLastError());
  1938.     }
  1939.  
  1940.     return S_OK;
  1941. }
  1942.  
  1943. HRESULT CSecurityDescriptor::AddAccessDeniedACEToACL(PACL *ppAcl, LPCTSTR pszPrincipal, DWORD dwAccessMask)
  1944. {
  1945.     ACL_SIZE_INFORMATION aclSizeInfo;
  1946.     int aclSize;
  1947.     DWORD returnValue;
  1948.     PSID principalSID;
  1949.     PACL oldACL, newACL;
  1950.  
  1951.     oldACL = *ppAcl;
  1952.  
  1953.     returnValue = GetPrincipalSID(pszPrincipal, &principalSID);
  1954.     if (FAILED(returnValue))
  1955.         return returnValue;
  1956.  
  1957.     aclSizeInfo.AclBytesInUse = 0;
  1958.     if (*ppAcl != NULL)
  1959.         GetAclInformation(oldACL, (LPVOID) &aclSizeInfo, sizeof(ACL_SIZE_INFORMATION), AclSizeInformation);
  1960.  
  1961.     aclSize = aclSizeInfo.AclBytesInUse + sizeof(ACL) + sizeof(ACCESS_DENIED_ACE) + GetLengthSid(principalSID) - sizeof(DWORD);
  1962.  
  1963.     ATLTRY(newACL = (PACL) new BYTE[aclSize]);
  1964.  
  1965.     if (!InitializeAcl(newACL, aclSize, ACL_REVISION))
  1966.     {
  1967.         free(principalSID);
  1968.         return HRESULT_FROM_WIN32(GetLastError());
  1969.     }
  1970.  
  1971.     if (!AddAccessDeniedAce(newACL, ACL_REVISION2, dwAccessMask, principalSID))
  1972.     {
  1973.         free(principalSID);
  1974.         return HRESULT_FROM_WIN32(GetLastError());
  1975.     }
  1976.  
  1977.     returnValue = CopyACL(newACL, oldACL);
  1978.     if (FAILED(returnValue))
  1979.     {
  1980.         free(principalSID);
  1981.         return returnValue;
  1982.     }
  1983.  
  1984.     *ppAcl = newACL;
  1985.  
  1986.     if (oldACL != NULL)
  1987.         free(oldACL);
  1988.     free(principalSID);
  1989.     return S_OK;
  1990. }
  1991.  
  1992.  
  1993. HRESULT CSecurityDescriptor::AddAccessAllowedACEToACL(PACL *ppAcl, LPCTSTR pszPrincipal, DWORD dwAccessMask)
  1994. {
  1995.     ACL_SIZE_INFORMATION aclSizeInfo;
  1996.     int aclSize;
  1997.     DWORD returnValue;
  1998.     PSID principalSID;
  1999.     PACL oldACL, newACL;
  2000.  
  2001.     oldACL = *ppAcl;
  2002.  
  2003.     returnValue = GetPrincipalSID(pszPrincipal, &principalSID);
  2004.     if (FAILED(returnValue))
  2005.         return returnValue;
  2006.  
  2007.     aclSizeInfo.AclBytesInUse = 0;
  2008.     if (*ppAcl != NULL)
  2009.         GetAclInformation(oldACL, (LPVOID) &aclSizeInfo, (DWORD) sizeof(ACL_SIZE_INFORMATION), AclSizeInformation);
  2010.  
  2011.     aclSize = aclSizeInfo.AclBytesInUse + sizeof(ACL) + sizeof(ACCESS_ALLOWED_ACE) + GetLengthSid(principalSID) - sizeof(DWORD);
  2012.  
  2013.     ATLTRY(newACL = (PACL) new BYTE[aclSize]);
  2014.  
  2015.     if (!InitializeAcl(newACL, aclSize, ACL_REVISION))
  2016.     {
  2017.         free(principalSID);
  2018.         return HRESULT_FROM_WIN32(GetLastError());
  2019.     }
  2020.  
  2021.     returnValue = CopyACL(newACL, oldACL);
  2022.     if (FAILED(returnValue))
  2023.     {
  2024.         free(principalSID);
  2025.         return returnValue;
  2026.     }
  2027.  
  2028.     if (!AddAccessAllowedAce(newACL, ACL_REVISION2, dwAccessMask, principalSID))
  2029.     {
  2030.         free(principalSID);
  2031.         return HRESULT_FROM_WIN32(GetLastError());
  2032.     }
  2033.  
  2034.     *ppAcl = newACL;
  2035.  
  2036.     if (oldACL != NULL)
  2037.         free(oldACL);
  2038.     free(principalSID);
  2039.     return S_OK;
  2040. }
  2041.  
  2042.  
  2043. HRESULT CSecurityDescriptor::RemovePrincipalFromACL(PACL pAcl, LPCTSTR pszPrincipal)
  2044. {
  2045.     ACL_SIZE_INFORMATION aclSizeInfo;
  2046.     ULONG i;
  2047.     LPVOID ace;
  2048.     ACCESS_ALLOWED_ACE *accessAllowedAce;
  2049.     ACCESS_DENIED_ACE *accessDeniedAce;
  2050.     SYSTEM_AUDIT_ACE *systemAuditAce;
  2051.     PSID principalSID;
  2052.     DWORD returnValue;
  2053.     ACE_HEADER *aceHeader;
  2054.  
  2055.     returnValue = GetPrincipalSID(pszPrincipal, &principalSID);
  2056.     if (FAILED(returnValue))
  2057.         return returnValue;
  2058.  
  2059.     GetAclInformation(pAcl, (LPVOID) &aclSizeInfo, (DWORD) sizeof(ACL_SIZE_INFORMATION), AclSizeInformation);
  2060.  
  2061.     for (i = 0; i < aclSizeInfo.AceCount; i++)
  2062.     {
  2063.         if (!GetAce(pAcl, i, &ace))
  2064.         {
  2065.             free(principalSID);
  2066.             return HRESULT_FROM_WIN32(GetLastError());
  2067.         }
  2068.  
  2069.         aceHeader = (ACE_HEADER *) ace;
  2070.  
  2071.         if (aceHeader->AceType == ACCESS_ALLOWED_ACE_TYPE)
  2072.         {
  2073.             accessAllowedAce = (ACCESS_ALLOWED_ACE *) ace;
  2074.  
  2075.             if (EqualSid(principalSID, (PSID) &accessAllowedAce->SidStart))
  2076.             {
  2077.                 DeleteAce(pAcl, i);
  2078.                 free(principalSID);
  2079.                 return S_OK;
  2080.             }
  2081.         } else
  2082.  
  2083.         if (aceHeader->AceType == ACCESS_DENIED_ACE_TYPE)
  2084.         {
  2085.             accessDeniedAce = (ACCESS_DENIED_ACE *) ace;
  2086.  
  2087.             if (EqualSid(principalSID, (PSID) &accessDeniedAce->SidStart))
  2088.             {
  2089.                 DeleteAce(pAcl, i);
  2090.                 free(principalSID);
  2091.                 return S_OK;
  2092.             }
  2093.         } else
  2094.  
  2095.         if (aceHeader->AceType == SYSTEM_AUDIT_ACE_TYPE)
  2096.         {
  2097.             systemAuditAce = (SYSTEM_AUDIT_ACE *) ace;
  2098.  
  2099.             if (EqualSid(principalSID, (PSID) &systemAuditAce->SidStart))
  2100.             {
  2101.                 DeleteAce(pAcl, i);
  2102.                 free(principalSID);
  2103.                 return S_OK;
  2104.             }
  2105.         }
  2106.     }
  2107.     free(principalSID);
  2108.     return S_OK;
  2109. }
  2110.  
  2111.  
  2112. HRESULT CSecurityDescriptor::SetPrivilege(LPCTSTR privilege, BOOL bEnable, HANDLE hToken)
  2113. {
  2114.     HRESULT hr;
  2115.     TOKEN_PRIVILEGES tpPrevious;
  2116.     TOKEN_PRIVILEGES tp;
  2117.     DWORD cbPrevious = sizeof(TOKEN_PRIVILEGES);
  2118.     LUID luid;
  2119.  
  2120.     // if no token specified open process token
  2121.     if (hToken == 0)
  2122.     {
  2123.         if (!OpenProcessToken(GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &hToken))
  2124.         {
  2125.             hr = HRESULT_FROM_WIN32(GetLastError());
  2126.             _ASSERTE(FALSE);
  2127.             return hr;
  2128.         }
  2129.     }
  2130.  
  2131.     if (!LookupPrivilegeValue(NULL, privilege, &luid ))
  2132.     {
  2133.         hr = HRESULT_FROM_WIN32(GetLastError());
  2134.         _ASSERTE(FALSE);
  2135.         return hr;
  2136.     }
  2137.  
  2138.     tp.PrivilegeCount = 1;
  2139.     tp.Privileges[0].Luid = luid;
  2140.     tp.Privileges[0].Attributes = 0;
  2141.  
  2142.     if (!AdjustTokenPrivileges(hToken, FALSE, &tp, sizeof(TOKEN_PRIVILEGES), &tpPrevious, &cbPrevious))
  2143.     {
  2144.         hr = HRESULT_FROM_WIN32(GetLastError());
  2145.         _ASSERTE(FALSE);
  2146.         return hr;
  2147.     }
  2148.  
  2149.     tpPrevious.PrivilegeCount = 1;
  2150.     tpPrevious.Privileges[0].Luid = luid;
  2151.  
  2152.     if (bEnable)
  2153.         tpPrevious.Privileges[0].Attributes |= (SE_PRIVILEGE_ENABLED);
  2154.     else
  2155.         tpPrevious.Privileges[0].Attributes ^= (SE_PRIVILEGE_ENABLED & tpPrevious.Privileges[0].Attributes);
  2156.  
  2157.     if (!AdjustTokenPrivileges(hToken, FALSE, &tpPrevious, cbPrevious, NULL, NULL))
  2158.     {
  2159.         hr = HRESULT_FROM_WIN32(GetLastError());
  2160.         _ASSERTE(FALSE);
  2161.         return hr;
  2162.     }
  2163.     return S_OK;
  2164. }
  2165.  
  2166. #endif //_ATL_NO_SECURITY
  2167. #endif //__ATLCOM_H__
  2168.  
  2169. #ifdef _DEBUG
  2170.  
  2171. void _cdecl AtlTrace(LPCTSTR lpszFormat, ...)
  2172. {
  2173.     va_list args;
  2174.     va_start(args, lpszFormat);
  2175.  
  2176.     int nBuf;
  2177.     TCHAR szBuffer[512];
  2178.  
  2179.     nBuf = _vstprintf(szBuffer, lpszFormat, args);
  2180.     _ASSERTE(nBuf < sizeof(szBuffer));
  2181.  
  2182.     OutputDebugString(szBuffer);
  2183.     va_end(args);
  2184. }
  2185. #endif
  2186.  
  2187. #ifndef ATL_NO_NAMESPACE
  2188. }; //namespace ATL
  2189. #endif
  2190.  
  2191. ///////////////////////////////////////////////////////////////////////////////
  2192. //All Global stuff goes below this line
  2193. ///////////////////////////////////////////////////////////////////////////////
  2194.  
  2195. /////////////////////////////////////////////////////////////////////////////
  2196. // Minimize CRT
  2197. // Specify DllMain as EntryPoint
  2198. // Turn off exception handling
  2199. // Define _ATL_MIN_CRT
  2200.  
  2201. #ifdef _ATL_MIN_CRT
  2202. /////////////////////////////////////////////////////////////////////////////
  2203. // Startup Code
  2204.  
  2205. #if defined(_WINDLL) || defined(_USRDLL)
  2206.  
  2207. // Declare DllMain
  2208. extern "C" BOOL WINAPI DllMain(HANDLE hDllHandle, DWORD dwReason, LPVOID lpReserved);
  2209.  
  2210. extern "C" BOOL WINAPI _DllMainCRTStartup(HANDLE hDllHandle, DWORD dwReason, LPVOID lpReserved)
  2211. {
  2212.     return DllMain(hDllHandle, dwReason, lpReserved);
  2213. }
  2214.  
  2215. #else
  2216.  
  2217. // wWinMain is not defined in winbase.h.
  2218. extern "C" int WINAPI wWinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPWSTR lpCmdLine, int nShowCmd);
  2219.  
  2220. #define SPACECHAR   _T(' ')
  2221. #define DQUOTECHAR  _T('\"')
  2222.  
  2223.  
  2224. #ifdef _UNICODE
  2225. extern "C" void wWinMainCRTStartup()
  2226. #else // _UNICODE
  2227. extern "C" void WinMainCRTStartup()
  2228. #endif // _UNICODE
  2229. {
  2230.     LPTSTR lpszCommandLine = ::GetCommandLine();
  2231.     if(lpszCommandLine == NULL)
  2232.         ::ExitProcess((UINT)-1);
  2233.  
  2234.     // Skip past program name (first token in command line).
  2235.     // Check for and handle quoted program name.
  2236.     if(*lpszCommandLine == DQUOTECHAR)
  2237.     {
  2238.         // Scan, and skip over, subsequent characters until
  2239.         // another double-quote or a null is encountered.
  2240.         do
  2241.         {
  2242.             lpszCommandLine = ::CharNext(lpszCommandLine);
  2243.         }
  2244.         while((*lpszCommandLine != DQUOTECHAR) && (*lpszCommandLine != _T('\0')));
  2245.  
  2246.         // If we stopped on a double-quote (usual case), skip over it.
  2247.         if(*lpszCommandLine == DQUOTECHAR)
  2248.             lpszCommandLine = ::CharNext(lpszCommandLine);
  2249.     }
  2250.     else
  2251.     {
  2252.         while(*lpszCommandLine > SPACECHAR)
  2253.             lpszCommandLine = ::CharNext(lpszCommandLine);
  2254.     }
  2255.  
  2256.     // Skip past any white space preceeding the second token.
  2257.     while(*lpszCommandLine && (*lpszCommandLine <= SPACECHAR))
  2258.         lpszCommandLine = ::CharNext(lpszCommandLine);
  2259.  
  2260.     STARTUPINFO StartupInfo;
  2261.     StartupInfo.dwFlags = 0;
  2262.     ::GetStartupInfo(&StartupInfo);
  2263.  
  2264.     int nRet = _tWinMain(::GetModuleHandle(NULL), NULL, lpszCommandLine,
  2265.         (StartupInfo.dwFlags & STARTF_USESHOWWINDOW) ?
  2266.         StartupInfo.wShowWindow : SW_SHOWDEFAULT);
  2267.  
  2268.     ::ExitProcess((UINT)nRet);
  2269. }
  2270.  
  2271. #endif // defined(_WINDLL) | defined(_USRDLL)
  2272.  
  2273. /////////////////////////////////////////////////////////////////////////////
  2274. // Heap Allocation
  2275.  
  2276. #ifndef _DEBUG
  2277.  
  2278. #ifndef _MERGE_PROXYSTUB
  2279. //rpcproxy.h does the same thing as this
  2280. int __cdecl _purecall()
  2281. {
  2282.     DebugBreak();
  2283.     return 0;
  2284. }
  2285. #endif
  2286.  
  2287. extern "C" const int _fltused = 0;
  2288.  
  2289. void* __cdecl malloc(size_t n)
  2290. {
  2291.     if (_Module.m_hHeap == NULL)
  2292.     {
  2293.         _Module.m_hHeap = HeapCreate(0, 0, 0);
  2294.         if (_Module.m_hHeap == NULL)
  2295.             return NULL;
  2296.     }
  2297.     _ASSERTE(_Module.m_hHeap != NULL);
  2298.  
  2299. #ifdef _MALLOC_ZEROINIT
  2300.     return HeapAlloc(_Module.m_hHeap, HEAP_ZERO_MEMORY, n);
  2301. #else
  2302.     return HeapAlloc(_Module.m_hHeap, 0, n);
  2303. #endif
  2304. }
  2305.  
  2306. void* __cdecl calloc(size_t n, size_t s)
  2307. {
  2308.     return malloc(n * s);
  2309. }
  2310.  
  2311. void* __cdecl realloc(void* p, size_t n)
  2312. {
  2313.     _ASSERTE(_Module.m_hHeap != NULL);
  2314. #ifdef _MALLOC_ZEROINIT
  2315.     return (p == NULL) ? malloc(n) : HeapReAlloc(_Module.m_hHeap, HEAP_ZERO_MEMORY, p, n);
  2316. #else
  2317.     return (p == NULL) ? malloc(n) : HeapReAlloc(_Module.m_hHeap, 0, p, n);
  2318. #endif
  2319. }
  2320.  
  2321. void __cdecl free(void* p)
  2322. {
  2323.     _ASSERTE(_Module.m_hHeap != NULL);
  2324.     HeapFree(_Module.m_hHeap, 0, p);
  2325. }
  2326.  
  2327. void* __cdecl operator new(size_t n)
  2328. {
  2329.     return malloc(n);
  2330. }
  2331.  
  2332. void __cdecl operator delete(void* p)
  2333. {
  2334.     free(p);
  2335. }
  2336.  
  2337. #endif  //_DEBUG
  2338.  
  2339. #endif //_ATL_MIN_CRT
  2340.  
  2341. #ifndef _ATL_DLL
  2342.  
  2343. #ifndef ATL_NO_NAMESPACE
  2344. #ifndef _ATL_DLL_IMPL
  2345. namespace ATL
  2346. {
  2347. #endif
  2348. #endif
  2349.  
  2350. /////////////////////////////////////////////////////////////////////////////
  2351. // statics
  2352.  
  2353. static UINT WINAPI AtlGetDirLen(LPCOLESTR lpszPathName)
  2354. {
  2355.     _ASSERTE(lpszPathName != NULL);
  2356.  
  2357.     // always capture the complete file name including extension (if present)
  2358.     LPCOLESTR lpszTemp = lpszPathName;
  2359.     for (LPCOLESTR lpsz = lpszPathName; *lpsz != NULL; )
  2360.     {
  2361.         LPCOLESTR lp = CharNextO(lpsz);
  2362.         // remember last directory/drive separator
  2363.         if (*lpsz == OLESTR('\\') || *lpsz == OLESTR('/') || *lpsz == OLESTR(':'))
  2364.             lpszTemp = lp;
  2365.         lpsz = lp;
  2366.     }
  2367.  
  2368.     return lpszTemp-lpszPathName;
  2369. }
  2370.  
  2371. /////////////////////////////////////////////////////////////////////////////
  2372. // QI support
  2373.  
  2374. ATLAPI AtlInternalQueryInterface(void* pThis,
  2375.     const _ATL_INTMAP_ENTRY* pEntries, REFIID iid, void** ppvObject)
  2376. {
  2377.     _ASSERTE(pThis != NULL);
  2378.     // First entry in the com map should be a simple map entry
  2379.     _ASSERTE(pEntries->pFunc == _ATL_SIMPLEMAPENTRY);
  2380.     if (ppvObject == NULL)
  2381.         return E_POINTER;
  2382.     *ppvObject = NULL;
  2383.     if (InlineIsEqualUnknown(iid)) // use first interface
  2384.     {
  2385.             IUnknown* pUnk = (IUnknown*)((int)pThis+pEntries->dw);
  2386.             pUnk->AddRef();
  2387.             *ppvObject = pUnk;
  2388.             return S_OK;
  2389.     }
  2390.     while (pEntries->pFunc != NULL)
  2391.     {
  2392.         BOOL bBlind = (pEntries->piid == NULL);
  2393.         if (bBlind || InlineIsEqualGUID(*(pEntries->piid), iid))
  2394.         {
  2395.             if (pEntries->pFunc == _ATL_SIMPLEMAPENTRY) //offset
  2396.             {
  2397.                 _ASSERTE(!bBlind);
  2398.                 IUnknown* pUnk = (IUnknown*)((int)pThis+pEntries->dw);
  2399.                 pUnk->AddRef();
  2400.                 *ppvObject = pUnk;
  2401.                 return S_OK;
  2402.             }
  2403.             else //actual function call
  2404.             {
  2405.                 HRESULT hRes = pEntries->pFunc(pThis,
  2406.                     iid, ppvObject, pEntries->dw);
  2407.                 if (hRes == S_OK || (!bBlind && FAILED(hRes)))
  2408.                     return hRes;
  2409.             }
  2410.         }
  2411.         pEntries++;
  2412.     }
  2413.     return E_NOINTERFACE;
  2414. }
  2415.  
  2416. /////////////////////////////////////////////////////////////////////////////
  2417. // Smart Pointer helpers
  2418.  
  2419. ATLAPI_(IUnknown*) AtlComPtrAssign(IUnknown** pp, IUnknown* lp)
  2420. {
  2421.     if (lp != NULL)
  2422.         lp->AddRef();
  2423.     if (*pp)
  2424.         (*pp)->Release();
  2425.     *pp = lp;
  2426.     return lp;
  2427. }
  2428.  
  2429. ATLAPI_(IUnknown*) AtlComQIPtrAssign(IUnknown** pp, IUnknown* lp, REFIID riid)
  2430. {
  2431.     IUnknown* pTemp = *pp;
  2432.     lp->QueryInterface(riid, (void**)pp);
  2433.     if (pTemp)
  2434.         pTemp->Release();
  2435.     return *pp;
  2436. }
  2437.  
  2438. /////////////////////////////////////////////////////////////////////////////
  2439. // Inproc Marshaling helpers
  2440.  
  2441. ATLAPI AtlFreeMarshalStream(IStream* pStream)
  2442. {
  2443.     if (pStream != NULL)
  2444.     {
  2445.         CoReleaseMarshalData(pStream);
  2446.         pStream->Release();
  2447.     }
  2448.     return S_OK;
  2449. }
  2450.  
  2451. ATLAPI AtlMarshalPtrInProc(IUnknown* pUnk, const IID& iid, IStream** ppStream)
  2452. {
  2453.     HRESULT hRes = CreateStreamOnHGlobal(NULL, TRUE, ppStream);
  2454.     if (SUCCEEDED(hRes))
  2455.     {
  2456.         hRes = CoMarshalInterface(*ppStream, iid,
  2457.             pUnk, MSHCTX_INPROC, NULL, MSHLFLAGS_TABLESTRONG);
  2458.         if (FAILED(hRes))
  2459.         {
  2460.             (*ppStream)->Release();
  2461.             *ppStream = NULL;
  2462.         }
  2463.     }
  2464.     return hRes;
  2465. }
  2466.  
  2467. ATLAPI AtlUnmarshalPtr(IStream* pStream, const IID& iid, IUnknown** ppUnk)
  2468. {
  2469.     *ppUnk = NULL;
  2470.     HRESULT hRes = E_INVALIDARG;
  2471.     if (pStream != NULL)
  2472.     {
  2473.         LARGE_INTEGER l;
  2474.         l.QuadPart = 0;
  2475.         pStream->Seek(l, STREAM_SEEK_SET, NULL);
  2476.         hRes = CoUnmarshalInterface(pStream, iid, (void**)ppUnk);
  2477.     }
  2478.     return hRes;
  2479. }
  2480.  
  2481. ATLAPI_(BOOL) AtlWaitWithMessageLoop(HANDLE hEvent)
  2482. {
  2483.     DWORD dwRet;
  2484.     MSG msg;
  2485.  
  2486.     while(1)
  2487.     {
  2488.         dwRet = MsgWaitForMultipleObjects(1, &hEvent, FALSE, INFINITE, QS_ALLINPUT);
  2489.  
  2490.         if (dwRet == WAIT_OBJECT_0)
  2491.             return TRUE;    // The event was signaled
  2492.  
  2493.         if (dwRet != WAIT_OBJECT_0 + 1)
  2494.             break;          // Something else happened
  2495.  
  2496.         // There is one or more window message available. Dispatch them
  2497.         while(PeekMessage(&msg,NULL,NULL,NULL,PM_REMOVE))
  2498.         {
  2499.             TranslateMessage(&msg);
  2500.             DispatchMessage(&msg);
  2501.             if (WaitForSingleObject(hEvent, 0) == WAIT_OBJECT_0)
  2502.                 return TRUE; // Event is now signaled.
  2503.         }
  2504.     }
  2505.     return FALSE;
  2506. }
  2507.  
  2508. /////////////////////////////////////////////////////////////////////////////
  2509. // Connection Point Helpers
  2510.  
  2511. ATLAPI AtlAdvise(IUnknown* pUnkCP, IUnknown* pUnk, const IID& iid, LPDWORD pdw)
  2512. {
  2513.     CComPtr<IConnectionPointContainer> pCPC;
  2514.     CComPtr<IConnectionPoint> pCP;
  2515.     HRESULT hRes = pUnkCP->QueryInterface(IID_IConnectionPointContainer, (void**)&pCPC);
  2516.     if (SUCCEEDED(hRes))
  2517.         hRes = pCPC->FindConnectionPoint(iid, &pCP);
  2518.     if (SUCCEEDED(hRes))
  2519.         hRes = pCP->Advise(pUnk, pdw);
  2520.     return hRes;
  2521. }
  2522.  
  2523. ATLAPI AtlUnadvise(IUnknown* pUnkCP, const IID& iid, DWORD dw)
  2524. {
  2525.     CComPtr<IConnectionPointContainer> pCPC;
  2526.     CComPtr<IConnectionPoint> pCP;
  2527.     HRESULT hRes = pUnkCP->QueryInterface(IID_IConnectionPointContainer, (void**)&pCPC);
  2528.     if (SUCCEEDED(hRes))
  2529.         hRes = pCPC->FindConnectionPoint(iid, &pCP);
  2530.     if (SUCCEEDED(hRes))
  2531.         hRes = pCP->Unadvise(dw);
  2532.     return hRes;
  2533. }
  2534.  
  2535. /////////////////////////////////////////////////////////////////////////////
  2536. // IDispatch Error handling
  2537.  
  2538. ATLAPI AtlSetErrorInfo(const CLSID& clsid, LPCOLESTR lpszDesc, DWORD dwHelpID,
  2539.     LPCOLESTR lpszHelpFile, const IID& iid, HRESULT hRes, HINSTANCE hInst)
  2540. {
  2541.     USES_CONVERSION;
  2542.     TCHAR szDesc[1024];
  2543.     szDesc[0] = NULL;
  2544.     // For a valid HRESULT the id should be in the range [0x0200, 0xffff]
  2545.     if (HIWORD(lpszDesc) == 0) //id
  2546.     {
  2547.         UINT nID = LOWORD((DWORD)lpszDesc);
  2548.         _ASSERTE((nID >= 0x0200 && nID <= 0xffff) || hRes != 0);
  2549.         if (LoadString(hInst, nID, szDesc, 1024) == 0)
  2550.         {
  2551.             _ASSERTE(FALSE);
  2552.             lstrcpy(szDesc, _T("Unknown Error"));
  2553.         }
  2554.         lpszDesc = T2OLE(szDesc);
  2555.         if (hRes == 0)
  2556.             hRes = MAKE_HRESULT(3, FACILITY_ITF, nID);
  2557.     }
  2558.  
  2559.     CComPtr<ICreateErrorInfo> pICEI;
  2560.     if (SUCCEEDED(CreateErrorInfo(&pICEI)))
  2561.     {
  2562.         CComPtr<IErrorInfo> pErrorInfo;
  2563.         pICEI->SetGUID(iid);
  2564.         LPOLESTR lpsz;
  2565.         ProgIDFromCLSID(clsid, &lpsz);
  2566.         if (lpsz != NULL)
  2567.             pICEI->SetSource(lpsz);
  2568.         if (dwHelpID != 0 && lpszHelpFile != NULL)
  2569.         {
  2570.             pICEI->SetHelpContext(dwHelpID);
  2571.             pICEI->SetHelpFile(const_cast<LPOLESTR>(lpszHelpFile));
  2572.         }
  2573.         CoTaskMemFree(lpsz);
  2574.         pICEI->SetDescription((LPOLESTR)lpszDesc);
  2575.         if (SUCCEEDED(pICEI->QueryInterface(IID_IErrorInfo, (void**)&pErrorInfo)))
  2576.             SetErrorInfo(0, pErrorInfo);
  2577.     }
  2578. //#ifdef _DEBUG
  2579. //  USES_CONVERSION;
  2580. //  ATLTRACE(_T("AtlReportError: Description=\"%s\" returning %x\n"), OLE2CT(lpszDesc), hRes);
  2581. //#endif
  2582.     return (hRes == 0) ? DISP_E_EXCEPTION : hRes;
  2583. }
  2584.  
  2585. /////////////////////////////////////////////////////////////////////////////
  2586. // Module
  2587.  
  2588. //Although these functions are big, they are only used once in a module
  2589. //so we should make them inline.
  2590.  
  2591. ATLAPI AtlModuleInit(_ATL_MODULE* pM, _ATL_OBJMAP_ENTRY* p, HINSTANCE h)
  2592. {
  2593.     _ASSERTE(pM != NULL);
  2594.     if (pM == NULL)
  2595.         return E_INVALIDARG;
  2596.     if (pM->cbSize < sizeof(_ATL_MODULE))
  2597.         return E_INVALIDARG;
  2598.     pM->m_pObjMap = p;
  2599.     pM->m_hInst = pM->m_hInstTypeLib = pM->m_hInstResource = h;
  2600.     pM->m_nLockCnt=0L;
  2601.     pM->m_hHeap = NULL;
  2602.     InitializeCriticalSection(&pM->m_csTypeInfoHolder);
  2603.     InitializeCriticalSection(&pM->m_csWindowCreate);
  2604.     InitializeCriticalSection(&pM->m_csObjMap);
  2605.     return S_OK;
  2606. }
  2607.  
  2608. ATLAPI AtlModuleRegisterClassObjects(_ATL_MODULE* pM, DWORD dwClsContext, DWORD dwFlags)
  2609. {
  2610.     _ASSERTE(pM != NULL);
  2611.     if (pM == NULL)
  2612.         return E_INVALIDARG;
  2613.     _ASSERTE(pM->m_pObjMap != NULL);
  2614.     _ATL_OBJMAP_ENTRY* pEntry = pM->m_pObjMap;
  2615.     HRESULT hRes = S_OK;
  2616.     while (pEntry->pclsid != NULL && hRes == S_OK)
  2617.     {
  2618.         hRes = pEntry->RegisterClassObject(dwClsContext, dwFlags);
  2619.         pEntry++;
  2620.     }
  2621.     return hRes;
  2622. }
  2623.  
  2624. ATLAPI AtlModuleRevokeClassObjects(_ATL_MODULE* pM)
  2625. {
  2626.     _ASSERTE(pM != NULL);
  2627.     if (pM == NULL)
  2628.         return E_INVALIDARG;
  2629.     _ASSERTE(pM->m_pObjMap != NULL);
  2630.     _ATL_OBJMAP_ENTRY* pEntry = pM->m_pObjMap;
  2631.     HRESULT hRes = S_OK;
  2632.     while (pEntry->pclsid != NULL && hRes == S_OK)
  2633.     {
  2634.         hRes = pEntry->RevokeClassObject();
  2635.         pEntry++;
  2636.     }
  2637.     return hRes;
  2638. }
  2639.  
  2640. ATLAPI AtlModuleGetClassObject(_ATL_MODULE* pM, REFCLSID rclsid, REFIID riid, LPVOID* ppv)
  2641. {
  2642.     _ASSERTE(pM != NULL);
  2643.     if (pM == NULL)
  2644.         return E_INVALIDARG;
  2645.     _ASSERTE(pM->m_pObjMap != NULL);
  2646.     _ATL_OBJMAP_ENTRY* pEntry = pM->m_pObjMap;
  2647.     HRESULT hRes = S_OK;
  2648.     if (ppv == NULL)
  2649.         return E_POINTER;
  2650.     while (pEntry->pclsid != NULL)
  2651.     {
  2652.         if (InlineIsEqualGUID(rclsid, *pEntry->pclsid))
  2653.         {
  2654.             if (pEntry->pCF == NULL)
  2655.             {
  2656.                 EnterCriticalSection(&pM->m_csObjMap);
  2657.                 if (pEntry->pCF == NULL)
  2658.                     hRes = pEntry->pfnGetClassObject(pEntry->pfnCreateInstance, IID_IUnknown, (LPVOID*)&pEntry->pCF);
  2659.                 LeaveCriticalSection(&pM->m_csObjMap);
  2660.             }
  2661.             if (pEntry->pCF != NULL)
  2662.                 hRes = pEntry->pCF->QueryInterface(riid, ppv);
  2663.             break;
  2664.         }
  2665.         pEntry++;
  2666.     }
  2667.     if (*ppv == NULL && hRes == S_OK)
  2668.         hRes = CLASS_E_CLASSNOTAVAILABLE;
  2669.     return hRes;
  2670. }
  2671.  
  2672. ATLAPI AtlModuleTerm(_ATL_MODULE* pM)
  2673. {
  2674.     _ASSERTE(pM != NULL);
  2675.     if (pM == NULL)
  2676.         return E_INVALIDARG;
  2677.     _ASSERTE(pM->m_hInst != NULL);
  2678.     if (pM->m_pObjMap != NULL)
  2679.     {
  2680.         _ATL_OBJMAP_ENTRY* pEntry = pM->m_pObjMap;
  2681.         while (pEntry->pclsid != NULL)
  2682.         {
  2683.             if (pEntry->pCF != NULL)
  2684.                 pEntry->pCF->Release();
  2685.             pEntry->pCF = NULL;
  2686.             pEntry++;
  2687.         }
  2688.     }
  2689.     DeleteCriticalSection(&pM->m_csTypeInfoHolder);
  2690.     DeleteCriticalSection(&pM->m_csWindowCreate);
  2691.     DeleteCriticalSection(&pM->m_csObjMap);
  2692.     if (pM->m_hHeap != NULL)
  2693.         HeapDestroy(pM->m_hHeap);
  2694.     return S_OK;
  2695. }
  2696.  
  2697. ATLAPI AtlModuleRegisterServer(_ATL_MODULE* pM, BOOL bRegTypeLib, const CLSID* pCLSID)
  2698. {
  2699.     _ASSERTE(pM != NULL);
  2700.     if (pM == NULL)
  2701.         return E_INVALIDARG;
  2702.     _ASSERTE(pM->m_hInst != NULL);
  2703.     _ASSERTE(pM->m_pObjMap != NULL);
  2704.     _ATL_OBJMAP_ENTRY* pEntry = pM->m_pObjMap;
  2705.     HRESULT hRes = S_OK;
  2706.     for (;pEntry->pclsid != NULL; pEntry++)
  2707.     {
  2708.         if (pCLSID == NULL)
  2709.         {
  2710.             if (pEntry->pfnGetObjectDescription() != NULL)
  2711.                 continue;
  2712.         }
  2713.         else
  2714.         {
  2715.             if (!IsEqualGUID(*pCLSID, *pEntry->pclsid))
  2716.                 continue;
  2717.         }
  2718.         hRes = pEntry->pfnUpdateRegistry(TRUE);
  2719.         if (FAILED(hRes))
  2720.             break;
  2721.     }
  2722.     if (SUCCEEDED(hRes) && bRegTypeLib)
  2723.         hRes = AtlModuleRegisterTypeLib(pM, 0);
  2724.     return hRes;
  2725. }
  2726.  
  2727. ATLAPI AtlModuleUnregisterServer(_ATL_MODULE* pM, const CLSID* pCLSID)
  2728. {
  2729.     _ASSERTE(pM != NULL);
  2730.     if (pM == NULL)
  2731.         return E_INVALIDARG;
  2732.     _ASSERTE(pM->m_hInst != NULL);
  2733.     _ASSERTE(pM->m_pObjMap != NULL);
  2734.     _ATL_OBJMAP_ENTRY* pEntry = pM->m_pObjMap;
  2735.     for (;pEntry->pclsid != NULL; pEntry++)
  2736.     {
  2737.         if (pCLSID == NULL)
  2738.         {
  2739.             if (pEntry->pfnGetObjectDescription() != NULL)
  2740.                 continue;
  2741.         }
  2742.         else
  2743.         {
  2744.             if (!IsEqualGUID(*pCLSID, *pEntry->pclsid))
  2745.                 continue;
  2746.         }
  2747.         pEntry->pfnUpdateRegistry(FALSE); //unregister
  2748.     }
  2749.     return S_OK;
  2750. }
  2751.  
  2752. ATLAPI AtlModuleUpdateRegistryFromResourceD(_ATL_MODULE* pM, LPCOLESTR lpszRes,
  2753.     BOOL bRegister, struct _ATL_REGMAP_ENTRY* pMapEntries, IRegistrar* pReg)
  2754. {
  2755.     USES_CONVERSION;
  2756.     _ASSERTE(pM != NULL);
  2757.     HRESULT hRes = S_OK;
  2758.     CComPtr<IRegistrar> p;
  2759.     if (pReg != NULL)
  2760.         p = pReg;
  2761.     else
  2762.     {
  2763.         hRes = CoCreateInstance(CLSID_Registrar, NULL,
  2764.             CLSCTX_INPROC_SERVER, IID_IRegistrar, (void**)&p);
  2765.     }
  2766.     if (SUCCEEDED(hRes))
  2767.     {
  2768.         TCHAR szModule[_MAX_PATH];
  2769.         GetModuleFileName(pM->m_hInst, szModule, _MAX_PATH);
  2770.         p->AddReplacement(OLESTR("Module"), T2OLE(szModule));
  2771.  
  2772.         if (NULL != pMapEntries)
  2773.         {
  2774.             while (NULL != pMapEntries->szKey)
  2775.             {
  2776.                 _ASSERTE(NULL != pMapEntries->szData);
  2777.                 p->AddReplacement((LPOLESTR)pMapEntries->szKey, (LPOLESTR)pMapEntries->szData);
  2778.                 pMapEntries++;
  2779.             }
  2780.         }
  2781.         LPCOLESTR szType = OLESTR("REGISTRY");
  2782.         GetModuleFileName(pM->m_hInstResource, szModule, _MAX_PATH);
  2783.         LPOLESTR pszModule = T2OLE(szModule);
  2784.         if (HIWORD(lpszRes)==0)
  2785.         {
  2786.             if (bRegister)
  2787.                 hRes = p->ResourceRegister(pszModule, ((UINT)LOWORD((DWORD)lpszRes)), szType);
  2788.             else
  2789.                 hRes = p->ResourceUnregister(pszModule, ((UINT)LOWORD((DWORD)lpszRes)), szType);
  2790.         }
  2791.         else
  2792.         {
  2793.             if (bRegister)
  2794.                 hRes = p->ResourceRegisterSz(pszModule, lpszRes, szType);
  2795.             else
  2796.                 hRes = p->ResourceUnregisterSz(pszModule, lpszRes, szType);
  2797.         }
  2798.  
  2799.     }
  2800.     return hRes;
  2801. }
  2802.  
  2803. /////////////////////////////////////////////////////////////////////////////
  2804. // TypeLib Support
  2805.  
  2806. ATLAPI AtlModuleRegisterTypeLib(_ATL_MODULE* pM, LPCOLESTR lpszIndex)
  2807. {
  2808.     _ASSERTE(pM != NULL);
  2809.     USES_CONVERSION;
  2810.     _ASSERTE(pM->m_hInstTypeLib != NULL);
  2811.     TCHAR szModule[_MAX_PATH+10];
  2812.     OLECHAR szDir[_MAX_PATH];
  2813.     GetModuleFileName(pM->m_hInstTypeLib, szModule, _MAX_PATH);
  2814.     if (lpszIndex != NULL)
  2815.         lstrcat(szModule, OLE2CT(lpszIndex));
  2816.     ITypeLib* pTypeLib;
  2817.     LPOLESTR lpszModule = T2OLE(szModule);
  2818.     HRESULT hr = LoadTypeLib(lpszModule, &pTypeLib);
  2819.     if (!SUCCEEDED(hr))
  2820.     {
  2821.         // typelib not in module, try <module>.tlb instead
  2822.         LPTSTR lpszExt = NULL;
  2823.         LPTSTR lpsz;
  2824.         for (lpsz = szModule; *lpsz != NULL; lpsz = CharNext(lpsz))
  2825.         {
  2826.             if (*lpsz == _T('.'))
  2827.                 lpszExt = lpsz;
  2828.         }
  2829.         if (lpszExt == NULL)
  2830.             lpszExt = lpsz;
  2831.         lstrcpy(lpszExt, _T(".tlb"));
  2832.         lpszModule = T2OLE(szModule);
  2833.         hr = LoadTypeLib(lpszModule, &pTypeLib);
  2834.     }
  2835.     if (SUCCEEDED(hr))
  2836.     {
  2837.         ocscpy(szDir, lpszModule);
  2838.         szDir[AtlGetDirLen(szDir)] = 0;
  2839.         hr = ::RegisterTypeLib(pTypeLib, lpszModule, szDir);
  2840.     }
  2841.     if (pTypeLib != NULL)
  2842.         pTypeLib->Release();
  2843.     return hr;
  2844. }
  2845.  
  2846. #ifndef ATL_NO_NAMESPACE
  2847. #ifndef _ATL_DLL_IMPL
  2848. }; //namespace ATL
  2849. #endif
  2850. #endif
  2851.  
  2852. #endif //!_ATL_DLL
  2853.